Skip to content
Draft
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
78 changes: 78 additions & 0 deletions src/binaryArith.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -818,6 +818,84 @@ void multTwoNumbers(CtPtrs& product, const CtPtrs& lhs, const CtPtrs& rhs,
addManyNumbers(product, nums, resSize, unpackSlotEncoding);
}


// Square number (i.e. an array of bits) a.
// Computes the pairwise products x_{i,j} = a_i * a_j simplified
// In case of a_i * a_j where i==j we get a_i * aj = ai = aj
// then sums the prodcuts using the 3-for-2 method.
// multiplication level reduced from size(a) to size(a) - 1
void square(CtPtrs &product, const CtPtrs &a,long sizeLimit,
std::vector<zzX>* unpackSlotEncoding) {
long aSize = lsize(a);
long resSize = 2 * aSize;
if (sizeLimit>0 && sizeLimit<resSize) resSize=sizeLimit;

if (a.numNonNull() < 1) {
setLengthZero(product);
return; // return 0
}

// Edge case, if a or b is 1 bit
if (aSize == 1) {
if (a[0]->isEmpty()) {
setLengthZero(product);
return;
}
vecCopy(product, a, aSize);
return;
}

// We make sure aa is the larger of the two integers
// to keep the number of additions to a minimum

NTL::Vec<NTL::Vec<Ctxt> > numbers(NTL::INIT_SIZE, aSize - 1);
const Ctxt *ct_ptr = a.ptr2nonNull();
long nNums = lsize(numbers);
for (long i = 0; i < nNums; i++)
numbers[i].SetLength(i + aSize - (((nNums - i) == 2) ? 0 : 1),
Ctxt(ZeroCtxtLike, *ct_ptr));

std::vector<std::pair<long, long> > pairs;

for (long i = 0; i < (nNums + 1); i++) {
for (long j = std::max(i, long(1)); j < (nNums + 1); j++) {
if (a.isSet(j) && !(a[j]->isEmpty()) &&
a.isSet(i) && !(a[i]->isEmpty()))
pairs.push_back(std::pair<long, long>(i, j));
}
}
long nPairs = lsize(pairs);
NTL_EXEC_RANGE(nPairs, first, last)
for (long idx = first; idx < last; idx++) {
long i, j;
std::tie(i, j) = pairs[idx];
if (j == i) {
if (i == nNums) {
numbers[i - 2][i + j - 2] = *(a[i]);
} else {
numbers[i][2 * i - 2] = *(a[i]);
}
} else {
numbers[i][i + j - 1] = *(a[i]);
numbers[i][i + j - 1].multiplyBy(*(a[j])); // multiply by the bit of b
}
}
NTL_EXEC_RANGE_END

const Ctxt* ctptr = a.ptr2nonNull();
{
NTL::Vec<Ctxt> prod;
CtPtrMat_VecCt nums(numbers); // A wrapper around numbers
CtPtrs_VecCt prod2(prod);
addManyNumbers(prod2, nums, resSize - 2,unpackSlotEncoding);
resize(product, resSize, Ctxt(ZeroCtxtLike, *ct_ptr));
*product[0] = *a[0];
*product[1] = Ctxt(ZeroCtxtLike,*ctptr);
for (size_t i = 0; i < prod.length(); i++)
*product[i + 2] = prod[i];
}
}

/* seven4Three: adding seven input bits, getting a 3-bit counter
*
* input: in[6..0]
Expand Down
10 changes: 10 additions & 0 deletions src/binaryArith.h
Original file line number Diff line number Diff line change
Expand Up @@ -71,6 +71,16 @@ void multTwoNumbers(CtPtrs& product, const CtPtrs& lhs, const CtPtrs& rhs,
bool rhsTwosComplement=false, long sizeLimit=0,
std::vector<zzX>* unpackSlotEncoding=nullptr);

/**
* @brief Square one number in binary representation where each ciphertext of the input vector contains a bit.
* @param product result of the squaring operation.
* @param a number to be squared
* @param sizeLimit number of bits to compute on, taken from the least significant end.
* @param unpackSlotEncoding vector of constants for unpacking, as used in bootstrapping.
**/
void square(CtPtrs &product, const CtPtrs &a,long sizeLimit=0,
std::vector<zzX>* unpackSlotEncoding =nullptr);

/**
* @brief Decrypt the binary numbers that are encrypted in eNums.
* @param pNums vector to decrypt the binary numbers into.
Expand Down