Skip to content

Commit 00d7fd6

Browse files
authored
Fboemer/more ciphertext async apis (#255)
* Async ciphertext APIs for rotation & mod-switching * Add ciphertext async APIs for inner product & sum
1 parent da2e56e commit 00d7fd6

File tree

11 files changed

+388
-133
lines changed

11 files changed

+388
-133
lines changed

Package.resolved

Lines changed: 1 addition & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

Package.swift

Lines changed: 1 addition & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -102,6 +102,7 @@ let package = Package(
102102
.target(
103103
name: "HomomorphicEncryption",
104104
dependencies: [
105+
.product(name: "AsyncAlgorithms", package: "swift-async-algorithms"),
105106
.product(name: "Crypto", package: "swift-crypto"),
106107
.product(name: "_CryptoExtras", package: "swift-crypto"),
107108
"CUtil",

Sources/HomomorphicEncryption/Ciphertext.swift

Lines changed: 179 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -12,6 +12,8 @@
1212
// See the License for the specific language governing permissions and
1313
// limitations under the License.
1414

15+
public import AsyncAlgorithms
16+
1517
/// Ciphertext type.
1618
public struct Ciphertext<Scheme: HeScheme, Format: PolyFormat>: Equatable, Sendable {
1719
public typealias Scalar = Scheme.Scalar
@@ -28,7 +30,7 @@ public struct Ciphertext<Scheme: HeScheme, Format: PolyFormat>: Equatable, Senda
2830
///
2931
/// After a fresh encryption, the ciphertext has ``HeScheme/freshCiphertextPolyCount`` polynomials.
3032
/// The count may change during the course of HE operations, e.g. increase during ciphertext multiplication,
31-
/// or decrease during relinearization ``Ciphertext/relinearize(using:)``.
33+
/// or decrease during relinearization ``Ciphertext/relinearize(using:)-41bsm``.
3234
public var polyCount: Int {
3335
polys.count
3436
}
@@ -314,7 +316,7 @@ public struct Ciphertext<Scheme: HeScheme, Format: PolyFormat>: Equatable, Senda
314316
///
315317
/// If the ciphertext already has a single modulus, this is a no-op.
316318
/// - Throws: Error upon failure to modulus switch.
317-
/// - seealso: ``Ciphertext/modSwitchDown()`` for more information and an alternative API.
319+
/// - seealso: ``Ciphertext/modSwitchDown()-4an2b`` for more information and an alternative API.
318320
@inlinable
319321
public mutating func modSwitchDownToSingle() throws where Format == Scheme.CanonicalCiphertextFormat {
320322
try Scheme.modSwitchDownToSingle(&self)
@@ -535,22 +537,31 @@ extension Ciphertext where Format == Scheme.CanonicalCiphertextFormat {
535537
}
536538

537539
extension Collection {
540+
/// Sums together the ciphertexts in the collection.
541+
/// - Throws: Precondition failure if the collection is empty.
542+
/// - Returns: The sum.
538543
@inlinable
539-
func sum<Scheme>() throws -> Element where Element == Ciphertext<Scheme, Eval> {
544+
public func sum<Scheme>() throws -> Element where Element == Ciphertext<Scheme, Eval> {
540545
precondition(!isEmpty)
541546
// swiftlint:disable:next force_unwrapping
542547
return try dropFirst().reduce(first!) { try $0 + $1 }
543548
}
544549

550+
/// Sums together the ciphertexts in the collection.
551+
/// - Throws: Precondition failure if the collection is empty.
552+
/// - Returns: The sum.
545553
@inlinable
546-
func sum<Scheme>() throws -> Element where Element == Ciphertext<Scheme, Coeff> {
554+
public func sum<Scheme>() throws -> Element where Element == Ciphertext<Scheme, Coeff> {
547555
precondition(!isEmpty)
548556
// swiftlint:disable:next force_unwrapping
549557
return try dropFirst().reduce(first!) { try $0 + $1 }
550558
}
551559

560+
/// Sums together the ciphertexts in the collection.
561+
/// - Throws: Precondition failure if the collection is empty.
562+
/// - Returns: The sum.
552563
@inlinable
553-
func sum<Scheme>() throws -> Element where Element == Ciphertext<Scheme, Scheme.CanonicalCiphertextFormat> {
564+
public func sum<Scheme>() throws -> Element where Element == Ciphertext<Scheme, Scheme.CanonicalCiphertextFormat> {
554565
precondition(!isEmpty)
555566
// swiftlint:disable:next force_unwrapping
556567
return try dropFirst().reduce(first!) { try $0 + $1 }
@@ -937,4 +948,167 @@ extension Ciphertext {
937948
}
938949
throw HeError.errorCastingPolyFormat(from: Format.self, to: Scheme.CanonicalCiphertextFormat.self)
939950
}
951+
952+
// MARK: Async rotations
953+
954+
/// Asynchronously rotates the columns of a ciphertext.
955+
///
956+
/// - Parameters:
957+
/// - step: Number of slots to rotate. Negative values indicate a left rotation, and positive values indicate a
958+
/// right rotation. Must have absolute value in `[1, N / 2 - 1]` where `N` is the RLWE ring dimension, given by
959+
/// ``EncryptionParameters/polyDegree``.
960+
/// - evaluationKey: Evaluation key to use in the HE computation. Must contain the Galois element associated with
961+
/// `step`, see ``GaloisElement/rotatingColumns(by:degree:)``.
962+
/// - Throws: failure to rotate ciphertext's columns.
963+
/// - seealso: ``HeScheme/rotateColumns(of:by:using:)-7h3fz`` for an alternate API and more information.
964+
@inlinable
965+
public mutating func rotateColumns(by step: Int,
966+
using evaluationKey: EvaluationKey<Scheme>) async throws
967+
where Format == Scheme.CanonicalCiphertextFormat
968+
{
969+
try await Scheme.rotateColumnsAsync(of: &self, by: step, using: evaluationKey)
970+
}
971+
972+
/// Asynchronously swaps the rows of a ciphertext.
973+
///
974+
/// A plaintext in ``EncodeFormat/simd`` format can be viewed a `2 x (N / 2)` matrix of coefficients.
975+
/// For instance, for `N = 8`, given a ciphertext encrypting a plaintext with values
976+
/// ```
977+
/// [1, 2, 3, 4, 5, 6, 7, 8]
978+
/// ```
979+
/// calling ``HeScheme/swapRows(of:using:)`` with `step: 1` will yield a ciphertext decrypting to
980+
/// ```
981+
/// [5, 6, 7, 8, 1, 2, 3, 4]
982+
/// ```
983+
/// - Parameter evaluationKey: Evaluation key to use in the HE computation. Must contain the Galois element
984+
/// associated with `step`, see ``GaloisElement/rotatingColumns(by:degree:)``.
985+
/// - Throws: error upon failure to swap the ciphertext's rows.
986+
/// - seealso: ``HeScheme/swapRows(of:using:)-50tac`` for an alternate API.
987+
@inlinable
988+
public mutating func swapRows(using evaluationKey: EvaluationKey<Scheme>) async throws
989+
where Format == Scheme.CanonicalCiphertextFormat
990+
{
991+
try await Scheme.swapRowsAsync(of: &self, using: evaluationKey)
992+
}
993+
994+
/// Asynchronously performs modulus switching on the ciphertext.
995+
///
996+
/// - Throws: Error upon failure to mod-switch.
997+
/// - seealso: ``HeScheme/modSwitchDown(_:)`` for an alternative API and more information.
998+
@inlinable
999+
public mutating func modSwitchDown() async throws where Format == Scheme.CanonicalCiphertextFormat {
1000+
try await Scheme.modSwitchDownAsync(&self)
1001+
}
1002+
1003+
/// Asynchronously performs modulus switching to a single modulus.
1004+
///
1005+
/// If the ciphertext already has a single modulus, this is a no-op.
1006+
/// - Throws: Error upon failure to modulus switch.
1007+
/// - seealso: ``Ciphertext/modSwitchDown()-4an2b`` for more information and an alternative API.
1008+
@inlinable
1009+
public mutating func modSwitchDownToSingle() async throws where Format == Scheme.CanonicalCiphertextFormat {
1010+
try await Scheme.modSwitchDownToSingleAsync(&self)
1011+
}
1012+
}
1013+
1014+
extension Ciphertext where Format == Scheme.CanonicalCiphertextFormat {
1015+
/// Asynchronously applies a Galois transformation.
1016+
///
1017+
/// - Parameters:
1018+
/// - element: Galois element of the transformation. Must be odd in `[1, 2 * N - 1]` where `N` is the RLWE ring
1019+
/// dimension, given by ``EncryptionParameters/polyDegree``.
1020+
/// - key: Evaluation key. Must contain Galois element `element`.
1021+
/// - Throws: Error upon failure to apply the Galois transformation.
1022+
/// - seealso: ``HeScheme/applyGalois(ciphertext:element:using:)`` for an alternative API and more information.
1023+
@inlinable
1024+
public mutating func applyGalois(element: Int, using key: EvaluationKey<Scheme>) async throws {
1025+
try await Scheme.applyGaloisAsync(ciphertext: &self, element: element, using: key)
1026+
}
1027+
1028+
/// Asynchronously Relinearizes the ciphertext.
1029+
///
1030+
/// - Parameter key: Evaluation key to relinearize with. Must contain a `RelinearizationKey`.
1031+
/// - Throws: Error upon failure to relinearize.
1032+
/// - seealso: ``HeScheme/relinearize(_:using:)`` for an alternative API and more information.
1033+
@inlinable
1034+
public mutating func relinearize(using key: EvaluationKey<Scheme>) async throws {
1035+
try await Scheme.relinearizeAsync(&self, using: key)
1036+
}
1037+
}
1038+
1039+
// MARK: - Async collection extensions
1040+
1041+
extension Collection {
1042+
/// Sums together the ciphertexts in the collection.
1043+
/// - Throws: Precondition failure if the collection is empty.
1044+
/// - Returns: The sum.
1045+
@inlinable
1046+
public func sum<Scheme>() async throws -> Element where Element == Ciphertext<Scheme, Eval> {
1047+
precondition(!isEmpty)
1048+
// swiftlint:disable:next force_unwrapping
1049+
return try await dropFirst().async.reduce(first!) { try await $0 + $1 }
1050+
}
1051+
1052+
/// Sums together the ciphertexts in the collection.
1053+
/// - Throws: Precondition failure if the collection is empty.
1054+
/// - Returns: The sum.
1055+
@inlinable
1056+
public func sum<Scheme>() async throws -> Element where Element == Ciphertext<Scheme, Coeff> {
1057+
precondition(!isEmpty)
1058+
// swiftlint:disable:next force_unwrapping
1059+
return try await dropFirst().async.reduce(first!) { try await $0 + $1 }
1060+
}
1061+
1062+
/// Sums together the ciphertexts in the collection.
1063+
/// - Throws: Precondition failure if the collection is empty.
1064+
/// - Returns: The sum.
1065+
@inlinable
1066+
public func sum<Scheme>() async throws -> Element where
1067+
Element == Ciphertext<Scheme, Scheme.CanonicalCiphertextFormat>
1068+
{
1069+
precondition(!isEmpty)
1070+
// swiftlint:disable:next force_unwrapping
1071+
return try await dropFirst().async.reduce(first!) { try await $0 + $1 }
1072+
}
1073+
1074+
/// Asynchronously computes an inner product between self and a collection of (optional) plaintexts in ``Eval``
1075+
/// format.
1076+
///
1077+
/// The inner product encrypts `sum_{i, plaintexts[i] != nil} self[i] * plaintexts[i]`. `plaintexts[i]`
1078+
/// may be `nil`, which denotes a zero plaintext.
1079+
/// - Parameter plaintexts: Plaintexts. Must not be empty and have `count` matching `self.count`.
1080+
/// - Returns: A ciphertext encrypting the inner product.
1081+
/// - Throws: Error upon failure to compute inner product.
1082+
@inlinable
1083+
public func innerProduct<Scheme>(plaintexts: some Collection<Plaintext<Scheme, Eval>?>) async throws -> Element
1084+
where Element == Ciphertext<Scheme, Eval>
1085+
{
1086+
try await Scheme.innerProductAsync(ciphertexts: self, plaintexts: plaintexts)
1087+
}
1088+
1089+
/// Asynchronously computes an inner product between self and a collection of plaintexts in ``Eval`` format.
1090+
///
1091+
/// The inner product encrypts `sum_{i} self[i] * plaintexts[i]`.
1092+
/// - Parameter plaintexts: Plaintexts. Must not be empty and have `count` matching `self.count`.
1093+
/// - Returns: A ciphertext encrypting the inner product.
1094+
/// - Throws: Error upon failure to compute inner product.
1095+
@inlinable
1096+
public func innerProduct<Scheme>(plaintexts: some Collection<Plaintext<Scheme, Eval>>) async throws -> Element
1097+
where Element == Ciphertext<Scheme, Eval>
1098+
{
1099+
try await Scheme.innerProductAsync(ciphertexts: self, plaintexts: plaintexts)
1100+
}
1101+
1102+
/// Asynchronously computes an inner product between self and another collection of ciphertexts.
1103+
///
1104+
/// The inner product encrypts `sum_{i} self[i] * ciphertexts[i]`.
1105+
/// - Parameter ciphertexts: Ciphertexts. Must not be empty and have `count` matching `self.count`.
1106+
/// - Returns: A ciphertext encrypting the inner product.
1107+
/// - Throws: Error upon failure to compute inner product.
1108+
@inlinable
1109+
public func innerProduct<Scheme>(ciphertexts: some Collection<Element>) async throws -> Element
1110+
where Element == Ciphertext<Scheme, Scheme.CanonicalCiphertextFormat>
1111+
{
1112+
try await Scheme.innerProductAsync(self, ciphertexts)
1113+
}
9401114
}

Sources/HomomorphicEncryption/HeScheme.swift

Lines changed: 4 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -485,7 +485,7 @@ public protocol HeScheme: Sendable {
485485
/// - evaluationKey: Evaluation key to use in the HE computation. Must contain the Galois element associated with
486486
/// `step`, see ``GaloisElement/rotatingColumns(by:degree:)``.
487487
/// - Throws: failure to rotate ciphertext's columns.
488-
/// - seealso: ``Ciphertext/rotateColumns(by:using:)`` for an alternate API.
488+
/// - seealso: ``Ciphertext/rotateColumns(by:using:)-4f3tp`` for an alternate API.
489489
static func rotateColumns(
490490
of ciphertext: inout CanonicalCiphertext,
491491
by step: Int,
@@ -513,7 +513,7 @@ public protocol HeScheme: Sendable {
513513
/// - evaluationKey: Evaluation key to use in the HE computation. Must contain the Galois element returned from
514514
/// ``GaloisElement/swappingRows(degree:)``.
515515
/// - Throws: error upon failure to swap the ciphertext's rows.
516-
/// - seealso: ``Ciphertext/swapRows(using:)`` for an alternate API. ``swapRowsAsync(of:using:)`` for an async
516+
/// - seealso: ``Ciphertext/swapRows(using:)-4o179`` for an alternate API. ``swapRowsAsync(of:using:)`` for an async
517517
/// version of this API
518518
static func swapRows(of ciphertext: inout CanonicalCiphertext, using evaluationKey: EvaluationKey) throws
519519

@@ -862,7 +862,7 @@ public protocol HeScheme: Sendable {
862862
/// serialization and sending the ciphertext to the secret key owner.
863863
/// - Parameter ciphertext: Ciphertext; must have > 1 ciphertext modulus.
864864
/// - Throws: Error upon failure to mod-switch.
865-
/// - seealso: ``Ciphertext/modSwitchDown()`` for an alternative API.
865+
/// - seealso: ``Ciphertext/modSwitchDown()-4an2b`` for an alternative API.
866866
/// - seealso: ``modSwitchDownAsync(_:)`` for an async version of this API
867867
static func modSwitchDown(_ ciphertext: inout CanonicalCiphertext) throws
868868

@@ -873,7 +873,7 @@ public protocol HeScheme: Sendable {
873873
///
874874
/// If the ciphertext already has a single modulus, this is a no-op.
875875
/// - Throws: Error upon failure to modulus switch.
876-
/// - seealso: ``Ciphertext/modSwitchDownToSingle()`` for more information and an alternative API.
876+
/// - seealso: ``Ciphertext/modSwitchDownToSingle()-3x0dy`` for more information and an alternative API.
877877
/// - seealso: ``modSwitchDownToSingleAsync(_:)`` for an async version of this API
878878
static func modSwitchDownToSingle(_ ciphertext: inout CanonicalCiphertext) throws
879879

Sources/PrivateInformationRetrieval/IndexPir/MulPir.swift

Lines changed: 6 additions & 10 deletions
Original file line numberDiff line numberDiff line change
@@ -359,10 +359,7 @@ extension MulPirServer {
359359
let startIndex = dataChunk.startIndex + expandedDim0Query.count * columnIndex
360360
let endIndex = min(startIndex + expandedDim0Query.count, dataChunk.endIndex)
361361
let plaintexts = dataChunk[startIndex..<endIndex]
362-
return try await Scheme.innerProductAsync(
363-
ciphertexts: expandedDim0Query,
364-
plaintexts: plaintexts)
365-
.convertToCanonicalFormat()
362+
return try await expandedDim0Query.innerProduct(plaintexts: plaintexts).convertToCanonicalFormat()
366363
})
367364

368365
var queryStartingIndex = expandedRemainingQuery.startIndex
@@ -373,17 +370,16 @@ extension MulPirServer {
373370
.async.map { startIndex in
374371
let vector0 = expandedRemainingQuery[currentIndex..<currentIndex + dimensionSize]
375372
let vector1 = currentResults[startIndex..<startIndex + dimensionSize]
376-
var product = try await Scheme.innerProductAsync(vector0, vector1)
377-
try await Scheme.relinearizeAsync(&product, using: evaluationKey)
373+
var product = try await vector0.innerProduct(ciphertexts: vector1)
374+
try await product.relinearize(using: evaluationKey)
378375
return product
379376
})
380377
queryStartingIndex += dimensionSize
381378
}
382379

383-
precondition(
384-
intermediateResults.count == 1,
385-
"There should be only 1 ciphertext in the final result for each chunk")
386-
try await Scheme.modSwitchDownToSingleAsync(&intermediateResults[0])
380+
precondition(intermediateResults.count == 1,
381+
"There should be only 1 ciphertext in the final result for each chunk")
382+
try await intermediateResults[0].modSwitchDownToSingle()
387383
return try await intermediateResults[0].convertToCoeffFormat()
388384
}
389385

Sources/PrivateInformationRetrieval/IndexPir/PirUtil.swift

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -90,7 +90,7 @@ extension PirUtilProtocol {
9090
let applyGaloisCount = 1 << ((targetElement - 1).log2 - (galoisElement - 1).log2)
9191
var currElement = 1
9292
for await _ in (0..<applyGaloisCount).async {
93-
try await Scheme.applyGaloisAsync(ciphertext: &c1, element: galoisElement, using: evaluationKey)
93+
try await c1.applyGalois(element: galoisElement, using: evaluationKey)
9494
currElement *= galoisElement
9595
currElement %= (2 * degree)
9696
}

Sources/PrivateNearestNeighborSearch/CiphertextMatrix.swift

Lines changed: 4 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -199,8 +199,8 @@ extension CiphertextMatrix {
199199
/// - Throws: Error upon failure to modulus switch.
200200
@inlinable
201201
public mutating func modSwitchDownToSingle() async throws where Format == Scheme.CanonicalCiphertextFormat {
202-
for index in 0..<ciphertexts.count {
203-
try await Scheme.modSwitchDownToSingleAsync(&ciphertexts[index])
202+
for index in ciphertexts.indices {
203+
try await ciphertexts[index].modSwitchDownToSingle()
204204
}
205205
}
206206
}
@@ -337,10 +337,7 @@ extension CiphertextMatrix {
337337
let rotateCount = simdColumnCount / (copiesInMask * columnCountPowerOfTwo) - 1
338338
var ciphertextCopyRight = ciphertext
339339
for await _ in (0..<rotateCount).async {
340-
try await Scheme.rotateColumnsAsync(
341-
of: &ciphertextCopyRight,
342-
by: columnCountPowerOfTwo,
343-
using: evaluationKey)
340+
try await ciphertextCopyRight.rotateColumns(by: columnCountPowerOfTwo, using: evaluationKey)
344341
try await ciphertext += ciphertextCopyRight
345342
}
346343
// e.g., `ciphertext` now encrypts
@@ -349,7 +346,7 @@ extension CiphertextMatrix {
349346

350347
// Duplicate values to both SIMD rows
351348
var ciphertextCopy = ciphertext
352-
try await Scheme.swapRowsAsync(of: &ciphertextCopy, using: evaluationKey)
349+
try await ciphertextCopy.swapRows(using: evaluationKey)
353350
try await ciphertext += ciphertextCopy
354351
// e.g., `ciphertext` now encrypts
355352
// [[3, 4, 3, 4, 3, 4, 3, 4],

Sources/PrivateNearestNeighborSearch/MatrixMultiplication.swift

Lines changed: 2 additions & 5 deletions
Original file line numberDiff line numberDiff line change
@@ -176,7 +176,7 @@ extension PlaintextMatrix {
176176
for step in 0..<babyStepGiantStep.babyStep {
177177
rotatedStates.append(state)
178178
if step != babyStepGiantStep.babyStep - 1 {
179-
try await Scheme.rotateColumnsAsync(of: &state, by: -1, using: evaluationKey)
179+
try await state.rotateColumns(by: -1, using: evaluationKey)
180180
}
181181
}
182182
let rotatedCiphertexts: [Scheme.EvalCiphertext] = try await .init(
@@ -201,10 +201,7 @@ extension PlaintextMatrix {
201201
let ciphertexts = rotatedCiphertexts[0..<plaintextRows.count]
202202

203203
// 2) Compute w_k
204-
let innerProduct =
205-
try await Scheme.innerProductAsync(
206-
ciphertexts: ciphertexts,
207-
plaintexts: plaintextRows)
204+
let innerProduct = try await ciphertexts.innerProduct(plaintexts: plaintextRows)
208205
return try await innerProduct.convertToCanonicalFormat()
209206
}
210207

0 commit comments

Comments
 (0)