diff --git a/src/binaryArith.cpp b/src/binaryArith.cpp index 1f55c2662..ec4e31cfc 100644 --- a/src/binaryArith.cpp +++ b/src/binaryArith.cpp @@ -818,6 +818,84 @@ void multTwoNumbers(CtPtrs& product, const CtPtrs& lhs, const CtPtrs& rhs, addManyNumbers(product, nums, resSize, unpackSlotEncoding); } + +// Square number (i.e. an array of bits) a. +// Computes the pairwise products x_{i,j} = a_i * a_j simplified +// In case of a_i * a_j where i==j we get a_i * aj = ai = aj +// then sums the prodcuts using the 3-for-2 method. +// multiplication level reduced from size(a) to size(a) - 1 +void square(CtPtrs &product, const CtPtrs &a,long sizeLimit, + std::vector* unpackSlotEncoding) { + long aSize = lsize(a); + long resSize = 2 * aSize; + if (sizeLimit>0 && sizeLimitisEmpty()) { + setLengthZero(product); + return; + } + vecCopy(product, a, aSize); + return; + } + + // We make sure aa is the larger of the two integers + // to keep the number of additions to a minimum + + NTL::Vec > numbers(NTL::INIT_SIZE, aSize - 1); + const Ctxt *ct_ptr = a.ptr2nonNull(); + long nNums = lsize(numbers); + for (long i = 0; i < nNums; i++) + numbers[i].SetLength(i + aSize - (((nNums - i) == 2) ? 0 : 1), + Ctxt(ZeroCtxtLike, *ct_ptr)); + + std::vector > pairs; + + for (long i = 0; i < (nNums + 1); i++) { + for (long j = std::max(i, long(1)); j < (nNums + 1); j++) { + if (a.isSet(j) && !(a[j]->isEmpty()) && + a.isSet(i) && !(a[i]->isEmpty())) + pairs.push_back(std::pair(i, j)); + } + } + long nPairs = lsize(pairs); + NTL_EXEC_RANGE(nPairs, first, last) + for (long idx = first; idx < last; idx++) { + long i, j; + std::tie(i, j) = pairs[idx]; + if (j == i) { + if (i == nNums) { + numbers[i - 2][i + j - 2] = *(a[i]); + } else { + numbers[i][2 * i - 2] = *(a[i]); + } + } else { + numbers[i][i + j - 1] = *(a[i]); + numbers[i][i + j - 1].multiplyBy(*(a[j])); // multiply by the bit of b + } + } + NTL_EXEC_RANGE_END + + const Ctxt* ctptr = a.ptr2nonNull(); + { + NTL::Vec prod; + CtPtrMat_VecCt nums(numbers); // A wrapper around numbers + CtPtrs_VecCt prod2(prod); + addManyNumbers(prod2, nums, resSize - 2,unpackSlotEncoding); + resize(product, resSize, Ctxt(ZeroCtxtLike, *ct_ptr)); + *product[0] = *a[0]; + *product[1] = Ctxt(ZeroCtxtLike,*ctptr); + for (size_t i = 0; i < prod.length(); i++) + *product[i + 2] = prod[i]; + } +} + /* seven4Three: adding seven input bits, getting a 3-bit counter * * input: in[6..0] diff --git a/src/binaryArith.h b/src/binaryArith.h index df80fd548..d8eb79019 100644 --- a/src/binaryArith.h +++ b/src/binaryArith.h @@ -71,6 +71,16 @@ void multTwoNumbers(CtPtrs& product, const CtPtrs& lhs, const CtPtrs& rhs, bool rhsTwosComplement=false, long sizeLimit=0, std::vector* unpackSlotEncoding=nullptr); +/** + * @brief Square one number in binary representation where each ciphertext of the input vector contains a bit. + * @param product result of the squaring operation. + * @param a number to be squared + * @param sizeLimit number of bits to compute on, taken from the least significant end. + * @param unpackSlotEncoding vector of constants for unpacking, as used in bootstrapping. + **/ +void square(CtPtrs &product, const CtPtrs &a,long sizeLimit=0, + std::vector* unpackSlotEncoding =nullptr); + /** * @brief Decrypt the binary numbers that are encrypted in eNums. * @param pNums vector to decrypt the binary numbers into.