Skip to content

Conversation

svkeerthy
Copy link
Contributor

@svkeerthy svkeerthy commented Jun 20, 2025

This PR adds out-of-place arithmetic operators (+, -, *) to the Embedding class in IR2Vec, complementing the existing in-place operators (+=, -=, *=).

Tests have been added to verify the functionality of these new operators.

(Tracking issue - #141817)

Copy link
Contributor Author

svkeerthy commented Jun 20, 2025

@svkeerthy svkeerthy changed the title Overloading operator+ for Embeddngs [IR2Vec] Overloading operator+ for Embeddngs Jun 20, 2025
@svkeerthy svkeerthy changed the title [IR2Vec] Overloading operator+ for Embeddngs [IR2Vec] Overloading operator+ for Embeddings Jun 20, 2025
@svkeerthy svkeerthy changed the title [IR2Vec] Overloading operator+ for Embeddings [IR2Vec] Overloading operator+ for `Embeddings Jun 20, 2025
@svkeerthy svkeerthy changed the title [IR2Vec] Overloading operator+ for `Embeddings [IR2Vec] Overloading operator+ for Embeddings Jun 20, 2025
@svkeerthy svkeerthy marked this pull request as ready for review June 20, 2025 23:33
@llvmbot llvmbot added mlgo llvm:analysis Includes value tracking, cost tables and constant folding labels Jun 20, 2025
@llvmbot
Copy link
Member

llvmbot commented Jun 20, 2025

@llvm/pr-subscribers-mlgo

@llvm/pr-subscribers-llvm-analysis

Author: S. VenkataKeerthy (svkeerthy)

Changes

Add out-of-place addition operator for Embedding class in IR2Vec.

This is used in subsequent patches.


Full diff: https://github.com/llvm/llvm-project/pull/145118.diff

3 Files Affected:

  • (modified) llvm/include/llvm/Analysis/IR2Vec.h (+1)
  • (modified) llvm/lib/Analysis/IR2Vec.cpp (+8)
  • (modified) llvm/unittests/Analysis/IR2VecTest.cpp (+18)
diff --git a/llvm/include/llvm/Analysis/IR2Vec.h b/llvm/include/llvm/Analysis/IR2Vec.h
index 480b834077b86..f6c40d36f8026 100644
--- a/llvm/include/llvm/Analysis/IR2Vec.h
+++ b/llvm/include/llvm/Analysis/IR2Vec.h
@@ -106,6 +106,7 @@ struct Embedding {
   const std::vector<double> &getData() const { return Data; }
 
   /// Arithmetic operators
+  Embedding operator+(const Embedding &RHS) const;
   Embedding &operator+=(const Embedding &RHS);
   Embedding &operator-=(const Embedding &RHS);
   Embedding &operator*=(double Factor);
diff --git a/llvm/lib/Analysis/IR2Vec.cpp b/llvm/lib/Analysis/IR2Vec.cpp
index 27cc2a4109879..d5d27db8bd2bf 100644
--- a/llvm/lib/Analysis/IR2Vec.cpp
+++ b/llvm/lib/Analysis/IR2Vec.cpp
@@ -71,6 +71,14 @@ inline bool fromJSON(const llvm::json::Value &E, Embedding &Out,
 // Embedding
 //===----------------------------------------------------------------------===//
 
+Embedding Embedding::operator+(const Embedding &RHS) const {
+  assert(this->size() == RHS.size() && "Vectors must have the same dimension");
+  Embedding Result(*this);
+  std::transform(this->begin(), this->end(), RHS.begin(), Result.begin(),
+                 std::plus<double>());
+  return Result;
+}
+
 Embedding &Embedding::operator+=(const Embedding &RHS) {
   assert(this->size() == RHS.size() && "Vectors must have the same dimension");
   std::transform(this->begin(), this->end(), RHS.begin(), this->begin(),
diff --git a/llvm/unittests/Analysis/IR2VecTest.cpp b/llvm/unittests/Analysis/IR2VecTest.cpp
index 33ac16828eb6c..50eb7f73c6f50 100644
--- a/llvm/unittests/Analysis/IR2VecTest.cpp
+++ b/llvm/unittests/Analysis/IR2VecTest.cpp
@@ -109,6 +109,18 @@ TEST(EmbeddingTest, ConstructorsAndAccessors) {
   }
 }
 
+TEST(EmbeddingTest, AddVectorsOutOfPlace) {
+  Embedding E1 = {1.0, 2.0, 3.0};
+  Embedding E2 = {0.5, 1.5, -1.0};
+
+  Embedding E3 = E1 + E2;
+  EXPECT_THAT(E3, ElementsAre(1.5, 3.5, 2.0));
+
+  // Check that E1 and E2 are unchanged
+  EXPECT_THAT(E1, ElementsAre(1.0, 2.0, 3.0));
+  EXPECT_THAT(E2, ElementsAre(0.5, 1.5, -1.0));
+}
+
 TEST(EmbeddingTest, AddVectors) {
   Embedding E1 = {1.0, 2.0, 3.0};
   Embedding E2 = {0.5, 1.5, -1.0};
@@ -180,6 +192,12 @@ TEST(EmbeddingTest, AccessOutOfBounds) {
   EXPECT_DEATH(E[4] = 4.0, "Index out of bounds");
 }
 
+TEST(EmbeddingTest, MismatchedDimensionsAddVectorsOutOfPlace) {
+  Embedding E1 = {1.0, 2.0};
+  Embedding E2 = {1.0};
+  EXPECT_DEATH(E1 + E2, "Vectors must have the same dimension");
+}
+
 TEST(EmbeddingTest, MismatchedDimensionsAddVectors) {
   Embedding E1 = {1.0, 2.0};
   Embedding E2 = {1.0};

@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-20-increasing_tolerance_in_approximatelyequals branch from d05856c to bf89c59 Compare June 23, 2025 21:10
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-20-overloading_operator_for_embeddngs branch 2 times, most recently from 23de35c to 8345bbe Compare June 30, 2025 20:56
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-20-increasing_tolerance_in_approximatelyequals branch 2 times, most recently from 0472d10 to b2c203a Compare June 30, 2025 21:11
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-20-overloading_operator_for_embeddngs branch 2 times, most recently from 2846872 to 187a8fb Compare July 1, 2025 01:11
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-20-increasing_tolerance_in_approximatelyequals branch 2 times, most recently from 6cf6937 to ec1d9d6 Compare July 1, 2025 01:20
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-20-overloading_operator_for_embeddngs branch from 187a8fb to 14e7d5b Compare July 1, 2025 01:20
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-20-increasing_tolerance_in_approximatelyequals branch from ec1d9d6 to 7516580 Compare July 1, 2025 01:25
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-20-overloading_operator_for_embeddngs branch from 14e7d5b to 06d0a11 Compare July 1, 2025 01:26
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-20-increasing_tolerance_in_approximatelyequals branch from 7516580 to 4a31512 Compare July 1, 2025 17:45
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-20-overloading_operator_for_embeddngs branch from 06d0a11 to f1976fa Compare July 1, 2025 17:46
@svkeerthy svkeerthy changed the title [IR2Vec] Overloading operator+ for Embeddings [IR2Vec] Add out-of-place arithmetic operators to Embedding class Jul 1, 2025
@svkeerthy svkeerthy requested a review from mtrofin July 1, 2025 17:47
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-20-increasing_tolerance_in_approximatelyequals branch from 4a31512 to 6cbae82 Compare July 1, 2025 18:53
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-20-overloading_operator_for_embeddngs branch from f1976fa to 10019ca Compare July 1, 2025 18:53
@svkeerthy svkeerthy requested a review from mtrofin July 1, 2025 18:55
Base automatically changed from users/svkeerthy/06-20-increasing_tolerance_in_approximatelyequals to main July 1, 2025 19:03
@svkeerthy svkeerthy force-pushed the users/svkeerthy/06-20-overloading_operator_for_embeddngs branch from 10019ca to c45d8a0 Compare July 1, 2025 19:05
Copy link
Contributor Author

svkeerthy commented Jul 1, 2025

Merge activity

  • Jul 1, 7:09 PM UTC: A user started a stack merge that includes this pull request via Graphite.
  • Jul 1, 7:09 PM UTC: @svkeerthy merged this pull request with Graphite.

Sign up for free to join this conversation on GitHub. Already have an account? Sign in to comment
Labels
llvm:analysis Includes value tracking, cost tables and constant folding mlgo
Projects
None yet
Development

Successfully merging this pull request may close these issues.

3 participants