diff --git a/components/usage/pkg/apiv1/usage.go b/components/usage/pkg/apiv1/usage.go index c0dcf2a6fa34cd..4de60ccadaa46c 100644 --- a/components/usage/pkg/apiv1/usage.go +++ b/components/usage/pkg/apiv1/usage.go @@ -236,6 +236,19 @@ func (s *UsageService) SetCostCenter(ctx context.Context, in *v1.SetCostCenterRe }, nil } +func (s UsageService) ResetUsage(ctx context.Context, req *v1.ResetUsageRequest) (*v1.ResetUsageResponse, error) { + now := time.Now() + costCentersToUpdate, err := s.costCenterManager.ListLatestCostCentersWithBillingTimeBefore(ctx, db.CostCenter_Other, now) + if err != nil { + log.WithError(err).Error("Failed to list cost centers to update.") + return nil, status.Errorf(codes.Internal, "Failed to identify expired cost centers for Other billing strategy") + } + + log.Infof("Identified %d expired cost centers at relative to %s", len(costCentersToUpdate), now.Format(time.RFC3339)) + + return &v1.ResetUsageResponse{}, nil +} + func (s *UsageService) ReconcileUsage(ctx context.Context, req *v1.ReconcileUsageRequest) (*v1.ReconcileUsageResponse, error) { from := req.GetFrom().AsTime() to := req.GetTo().AsTime() diff --git a/components/usage/pkg/db/cost_center.go b/components/usage/pkg/db/cost_center.go index fb4fabb27e798c..4a78c1a2601fe1 100644 --- a/components/usage/pkg/db/cost_center.go +++ b/components/usage/pkg/db/cost_center.go @@ -32,7 +32,7 @@ type CostCenter struct { SpendingLimit int32 `gorm:"column:spendingLimit;type:int;default:0;" json:"spendingLimit"` BillingStrategy BillingStrategy `gorm:"column:billingStrategy;type:varchar;size:255;" json:"billingStrategy"` NextBillingTime VarcharTime `gorm:"column:nextBillingTime;type:varchar;size:255;" json:"nextBillingTime"` - LastModified time.Time `gorm:"->:column:_lastModified;type:timestamp;default:CURRENT_TIMESTAMP(6);" json:"_lastModified"` + LastModified time.Time `gorm:"->;column:_lastModified;type:timestamp;default:CURRENT_TIMESTAMP(6);" json:"_lastModified"` } // TableName sets the insert table name for this struct type @@ -233,3 +233,33 @@ func (c *CostCenterManager) ComputeInvoiceUsageRecord(ctx context.Context, attri Draft: false, }, nil } + +func (c *CostCenterManager) ListLatestCostCentersWithBillingTimeBefore(ctx context.Context, strategy BillingStrategy, billingTimeBefore time.Time) ([]CostCenter, error) { + db := c.conn.WithContext(ctx) + + var results []CostCenter + var batch []CostCenter + + subquery := db. + Table((&CostCenter{}).TableName()). + // Retrieve the latest CostCenter for a given (attribution) ID. + Select("DISTINCT id, MAX(creationTime) AS creationTime"). + Group("id") + tx := db. + Table(fmt.Sprintf("%s as cc", (&CostCenter{}).TableName())). + // Join on our set of latest CostCenter records + Joins("INNER JOIN (?) AS expiredCC on cc.id = expiredCC.id AND cc.creationTime = expiredCC.creationTime", subquery). + Where("cc.billingStrategy = ?", strategy). + Where("nextBillingTime != ?", ""). + Where("nextBillingTime < ?", TimeToISO8601(billingTimeBefore)). + FindInBatches(&batch, 1000, func(tx *gorm.DB, iteration int) error { + results = append(results, batch...) + return nil + }) + + if tx.Error != nil { + return nil, fmt.Errorf("failed to list cost centers with billing time before: %w", tx.Error) + } + + return results, nil +} diff --git a/components/usage/pkg/db/cost_center_test.go b/components/usage/pkg/db/cost_center_test.go index d94af9ee80e04e..80c7cf5fed3183 100644 --- a/components/usage/pkg/db/cost_center_test.go +++ b/components/usage/pkg/db/cost_center_test.go @@ -244,3 +244,92 @@ func requireCostCenterEqual(t *testing.T, expected, actual db.CostCenter) { require.EqualValues(t, expected.SpendingLimit, actual.SpendingLimit) require.Equal(t, expected.BillingStrategy, actual.BillingStrategy) } + +func TestCostCenter_ListLatestCostCentersWithBillingTimeBefore(t *testing.T) { + + t.Run("no cost centers found when no data exists", func(t *testing.T) { + conn := dbtest.ConnectForTests(t) + mnr := db.NewCostCenterManager(conn, db.DefaultSpendingLimit{ + ForTeams: 0, + ForUsers: 500, + }) + + ts := time.Date(2022, 10, 10, 10, 10, 10, 10, time.UTC) + + retrieved, err := mnr.ListLatestCostCentersWithBillingTimeBefore(context.Background(), db.CostCenter_Other, ts.Add(7*24*time.Hour)) + require.NoError(t, err) + require.Len(t, retrieved, 0) + }) + + t.Run("returns the most recent cost center (by creation time)", func(t *testing.T) { + conn := dbtest.ConnectForTests(t) + mnr := db.NewCostCenterManager(conn, db.DefaultSpendingLimit{ + ForTeams: 0, + ForUsers: 500, + }) + + attributionID := uuid.New().String() + firstCreation := time.Date(2022, 10, 10, 10, 10, 10, 10, time.UTC) + secondCreation := firstCreation.Add(24 * time.Hour) + + costCenters := []db.CostCenter{ + dbtest.NewCostCenter(t, db.CostCenter{ + ID: db.NewTeamAttributionID(attributionID), + SpendingLimit: 100, + CreationTime: db.NewVarcharTime(firstCreation), + BillingStrategy: db.CostCenter_Other, + NextBillingTime: db.NewVarcharTime(firstCreation), + }), + dbtest.NewCostCenter(t, db.CostCenter{ + ID: db.NewTeamAttributionID(attributionID), + SpendingLimit: 100, + CreationTime: db.NewVarcharTime(secondCreation), + BillingStrategy: db.CostCenter_Other, + NextBillingTime: db.NewVarcharTime(secondCreation), + }), + } + + dbtest.CreateCostCenters(t, conn, costCenters...) + + retrieved, err := mnr.ListLatestCostCentersWithBillingTimeBefore(context.Background(), db.CostCenter_Other, secondCreation.Add(7*24*time.Hour)) + require.NoError(t, err) + require.Len(t, retrieved, 1) + + requireCostCenterEqual(t, costCenters[1], retrieved[0]) + }) + + t.Run("returns results only when most recent cost center matches billing strategy", func(t *testing.T) { + conn := dbtest.ConnectForTests(t) + mnr := db.NewCostCenterManager(conn, db.DefaultSpendingLimit{ + ForTeams: 0, + ForUsers: 500, + }) + + attributionID := uuid.New().String() + firstCreation := time.Date(2022, 10, 10, 10, 10, 10, 10, time.UTC) + secondCreation := firstCreation.Add(24 * time.Hour) + + costCenters := []db.CostCenter{ + dbtest.NewCostCenter(t, db.CostCenter{ + ID: db.NewTeamAttributionID(attributionID), + SpendingLimit: 100, + CreationTime: db.NewVarcharTime(firstCreation), + BillingStrategy: db.CostCenter_Other, + NextBillingTime: db.NewVarcharTime(firstCreation), + }), + dbtest.NewCostCenter(t, db.CostCenter{ + ID: db.NewTeamAttributionID(attributionID), + SpendingLimit: 100, + CreationTime: db.NewVarcharTime(secondCreation), + BillingStrategy: db.CostCenter_Stripe, + }), + } + + dbtest.CreateCostCenters(t, conn, costCenters...) + + retrieved, err := mnr.ListLatestCostCentersWithBillingTimeBefore(context.Background(), db.CostCenter_Other, secondCreation.Add(7*24*time.Hour)) + require.NoError(t, err) + require.Len(t, retrieved, 0) + }) + +}