Skip to content

Commit cdb2d9e

Browse files
author
Laurie T. Malau
committed
[usage] Implement CreateStripeSubscription
1 parent 02e5789 commit cdb2d9e

File tree

5 files changed

+186
-17
lines changed

5 files changed

+186
-17
lines changed

components/usage/pkg/apiv1/billing.go

Lines changed: 85 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,20 @@ import (
2323
"gorm.io/gorm"
2424
)
2525

26-
func NewBillingService(stripeClient *stripe.Client, conn *gorm.DB, ccManager *db.CostCenterManager) *BillingService {
26+
func NewBillingService(stripeClient *stripe.Client, conn *gorm.DB, ccManager *db.CostCenterManager, stripePrices stripe.StripePrices) *BillingService {
2727
return &BillingService{
2828
stripeClient: stripeClient,
2929
conn: conn,
3030
ccManager: ccManager,
31+
stripePrices: stripePrices,
3132
}
3233
}
3334

3435
type BillingService struct {
3536
conn *gorm.DB
3637
stripeClient *stripe.Client
3738
ccManager *db.CostCenterManager
39+
stripePrices stripe.StripePrices
3840

3941
v1.UnimplementedBillingServiceServer
4042
}
@@ -165,6 +167,88 @@ func (s *BillingService) CreateStripeCustomer(ctx context.Context, req *v1.Creat
165167
}, nil
166168
}
167169

170+
func (s *BillingService) CreateStripeSubscription(ctx context.Context, req *v1.CreateStripeSubscriptionRequest) (*v1.CreateStripeSubscriptionResponse, error) {
171+
attributionID, err := db.ParseAttributionID(req.GetAttributionId())
172+
if err != nil {
173+
return nil, status.Errorf(codes.InvalidArgument, "Invalid attribution ID %s", attributionID)
174+
}
175+
176+
customer, err := s.GetStripeCustomer(ctx, &v1.GetStripeCustomerRequest{
177+
Identifier: &v1.GetStripeCustomerRequest_AttributionId{
178+
AttributionId: string(attributionID),
179+
},
180+
})
181+
if err != nil {
182+
return nil, err
183+
}
184+
185+
_, err = s.stripeClient.SetDefaultPaymentForCustomer(ctx, customer.Customer.Id, req.SetupIntentId)
186+
if err != nil {
187+
return nil, status.Errorf(codes.InvalidArgument, "Failed to set default payment for customer ID %s", customer.Customer.Id)
188+
}
189+
190+
stripeCustomer, err := s.stripeClient.GetCustomer(ctx, customer.Customer.Id)
191+
if err != nil {
192+
return nil, err
193+
}
194+
195+
priceID, err := getPriceIdentifier(attributionID, stripeCustomer, s)
196+
if err != nil {
197+
return nil, err
198+
}
199+
200+
var isAutomaticTaxSupported bool
201+
if stripeCustomer.Tax != nil {
202+
isAutomaticTaxSupported = stripeCustomer.Tax.AutomaticTax == "supported"
203+
}
204+
if !isAutomaticTaxSupported {
205+
log.Warnf("Automatic Stripe tax is not supported for customer %s", stripeCustomer.ID)
206+
}
207+
208+
subscription, err := s.stripeClient.CreateSubscription(ctx, stripeCustomer.ID, priceID, isAutomaticTaxSupported)
209+
if err != nil {
210+
return nil, status.Errorf(codes.Internal, "Failed to create subscription with customer ID %s", customer.Customer.Id)
211+
}
212+
213+
return &v1.CreateStripeSubscriptionResponse{
214+
Subscription: &v1.StripeSubscription{
215+
Id: subscription.ID,
216+
},
217+
}, nil
218+
}
219+
220+
func getPriceIdentifier(attributionID db.AttributionID, stripeCustomer *stripe_api.Customer, s *BillingService) (string, error) {
221+
preferredCurrency := stripeCustomer.Metadata["preferredCurrency"]
222+
if stripeCustomer.Metadata["preferredCurrency"] == "" {
223+
log.
224+
WithField("stripe_customer_id", stripeCustomer.ID).
225+
Warn("No preferred currency set. Defaulting to USD")
226+
}
227+
228+
entity, _ := attributionID.Values()
229+
230+
switch entity {
231+
case db.AttributionEntity_User:
232+
switch preferredCurrency {
233+
case "EUR":
234+
return s.stripePrices.IndividualUsagePriceIDs.EUR, nil
235+
default:
236+
return s.stripePrices.IndividualUsagePriceIDs.USD, nil
237+
}
238+
239+
case db.AttributionEntity_Team:
240+
switch preferredCurrency {
241+
case "EUR":
242+
return s.stripePrices.TeamUsagePriceIDs.EUR, nil
243+
default:
244+
return s.stripePrices.TeamUsagePriceIDs.USD, nil
245+
}
246+
247+
default:
248+
return "", status.Errorf(codes.InvalidArgument, "Invalid currency %s for customer ID %s", stripeCustomer.Metadata["preferredCurrency"], stripeCustomer.ID)
249+
}
250+
}
251+
168252
func (s *BillingService) ReconcileInvoices(ctx context.Context, in *v1.ReconcileInvoicesRequest) (*v1.ReconcileInvoicesResponse, error) {
169253
balances, err := db.ListBalance(ctx, s.conn)
170254
if err != nil {

components/usage/pkg/server/server.go

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,7 @@ type Config struct {
4444
DefaultSpendingLimit db.DefaultSpendingLimit `json:"defaultSpendingLimit"`
4545

4646
// StripePrices configure which Stripe Price IDs should be used
47-
StripePrices StripePrices `json:"stripePrices"`
48-
}
49-
50-
type PriceConfig struct {
51-
EUR string `json:"eur"`
52-
USD string `json:"usd"`
53-
}
54-
55-
type StripePrices struct {
56-
IndividualUsagePriceIDs PriceConfig `json:"individualUsagePriceIds"`
57-
TeamUsagePriceIDs PriceConfig `json:"teamUsagePriceIds"`
47+
StripePrices stripe.StripePrices `json:"stripePrices"`
5848
}
5949

6050
func Start(cfg Config, version string) error {
@@ -188,7 +178,7 @@ func registerGRPCServices(srv *baseserver.Server, conn *gorm.DB, stripeClient *s
188178
if stripeClient == nil {
189179
v1.RegisterBillingServiceServer(srv.GRPC(), &apiv1.BillingServiceNoop{})
190180
} else {
191-
v1.RegisterBillingServiceServer(srv.GRPC(), apiv1.NewBillingService(stripeClient, conn, ccManager))
181+
v1.RegisterBillingServiceServer(srv.GRPC(), apiv1.NewBillingService(stripeClient, conn, ccManager, cfg.StripePrices))
192182
}
193183
return nil
194184
}

components/usage/pkg/stripe/stripe.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@ type ClientConfig struct {
3737
SecretKey string `json:"secretKey"`
3838
}
3939

40+
type PriceConfig struct {
41+
EUR string `json:"eur"`
42+
USD string `json:"usd"`
43+
}
44+
45+
type StripePrices struct {
46+
IndividualUsagePriceIDs PriceConfig `json:"individualUsagePriceIds"`
47+
TeamUsagePriceIDs PriceConfig `json:"teamUsagePriceIds"`
48+
}
49+
4050
func ReadConfigFromFile(path string) (ClientConfig, error) {
4151
bytes, err := os.ReadFile(path)
4252
if err != nil {
@@ -320,6 +330,82 @@ func (c *Client) GetSubscriptionWithCustomer(ctx context.Context, subscriptionID
320330
return subscription, nil
321331
}
322332

333+
func (c *Client) CreateSubscription(ctx context.Context, customerID string, priceID string, isAutomaticTaxSupported bool) (*stripe.Subscription, error) {
334+
if customerID == "" {
335+
return nil, fmt.Errorf("no customerID specified")
336+
}
337+
if priceID == "" {
338+
return nil, fmt.Errorf("no priceID specified")
339+
}
340+
341+
startOfNextMonth := getStartOfNextMonth(time.Now())
342+
343+
params := &stripe.SubscriptionParams{
344+
Customer: stripe.String(customerID),
345+
Items: []*stripe.SubscriptionItemsParams{
346+
{
347+
Price: stripe.String(priceID),
348+
},
349+
},
350+
AutomaticTax: &stripe.SubscriptionAutomaticTaxParams{
351+
Enabled: stripe.Bool(isAutomaticTaxSupported),
352+
},
353+
BillingCycleAnchor: stripe.Int64(startOfNextMonth.Unix()),
354+
}
355+
356+
subscription, err := c.sc.Subscriptions.New(params)
357+
if err != nil {
358+
return nil, fmt.Errorf("failed to get subscription with customer ID %s", customerID)
359+
}
360+
361+
return subscription, err
362+
}
363+
364+
func getStartOfNextMonth(t time.Time) time.Time {
365+
currentYear, currentMonth, _ := t.Date()
366+
367+
firstOfMonth := time.Date(currentYear, currentMonth, 1, 0, 0, 0, 0, time.UTC)
368+
startOfNextMonth := firstOfMonth.AddDate(0, 1, 0)
369+
370+
return startOfNextMonth
371+
}
372+
373+
func (c *Client) SetDefaultPaymentForCustomer(ctx context.Context, customerID string, setupIntentId string) (*stripe.Customer, error) {
374+
if customerID == "" {
375+
return nil, fmt.Errorf("no customerID specified")
376+
}
377+
378+
if setupIntentId == "" {
379+
return nil, fmt.Errorf("no setupIntentID specified")
380+
}
381+
382+
setupIntent, err := c.sc.SetupIntents.Get(setupIntentId, &stripe.SetupIntentParams{
383+
Params: stripe.Params{
384+
Context: ctx,
385+
},
386+
})
387+
if err != nil {
388+
return nil, fmt.Errorf("Failed to retrieve setup intent with id %s", setupIntentId)
389+
}
390+
391+
paymentMethod, err := c.sc.PaymentMethods.Attach(setupIntent.PaymentMethod.ID, &stripe.PaymentMethodAttachParams{Customer: &customerID})
392+
if err != nil {
393+
return nil, fmt.Errorf("Failed to attach payment method to setup intent ID %s", setupIntentId)
394+
}
395+
396+
customer, _ := c.sc.Customers.Update(customerID, &stripe.CustomerParams{
397+
InvoiceSettings: &stripe.CustomerInvoiceSettingsParams{
398+
DefaultPaymentMethod: stripe.String(paymentMethod.ID)},
399+
Address: &stripe.AddressParams{
400+
Line1: stripe.String(paymentMethod.BillingDetails.Address.Line1),
401+
Country: stripe.String(paymentMethod.BillingDetails.Address.Country)}})
402+
if err != nil {
403+
return nil, fmt.Errorf("Failed to update customer with id %s", customerID)
404+
}
405+
406+
return customer, nil
407+
}
408+
323409
func GetAttributionID(ctx context.Context, customer *stripe.Customer) (db.AttributionID, error) {
324410
if customer == nil {
325411
log.Error("No customer information available for invoice.")

components/usage/pkg/stripe/stripe_test.go

Lines changed: 9 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ package stripe
66

77
import (
88
"fmt"
9-
"github.com/gitpod-io/gitpod/usage/pkg/db"
109
"testing"
10+
"time"
11+
12+
"github.com/gitpod-io/gitpod/usage/pkg/db"
1113

1214
"github.com/stretchr/testify/require"
1315
)
@@ -89,3 +91,9 @@ func TestCustomerQueriesForTeamIds_MultipleQueries(t *testing.T) {
8991
})
9092
}
9193
}
94+
95+
func TestStartOfNextMonth(t *testing.T) {
96+
ts := time.Date(2022, 10, 1, 0, 0, 0, 0, time.UTC)
97+
98+
require.Equal(t, time.Date(2022, 11, 1, 0, 0, 0, 0, time.UTC), getStartOfNextMonth(ts))
99+
}

install/installer/pkg/components/usage/configmap.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/gitpod-io/gitpod/common-go/baseserver"
1111
"github.com/gitpod-io/gitpod/usage/pkg/db"
1212
"github.com/gitpod-io/gitpod/usage/pkg/server"
13+
"github.com/gitpod-io/gitpod/usage/pkg/stripe"
1314

1415
"github.com/gitpod-io/gitpod/installer/pkg/common"
1516
"github.com/gitpod-io/gitpod/installer/pkg/config/v1/experimental"
@@ -39,12 +40,12 @@ func configmap(ctx *common.RenderContext) ([]runtime.Object, error) {
3940

4041
expWebAppConfig := getExperimentalWebAppConfig(ctx)
4142
if expWebAppConfig != nil && expWebAppConfig.Stripe != nil {
42-
cfg.StripePrices = server.StripePrices{
43-
IndividualUsagePriceIDs: server.PriceConfig{
43+
cfg.StripePrices = stripe.StripePrices{
44+
IndividualUsagePriceIDs: stripe.PriceConfig{
4445
EUR: expWebAppConfig.Stripe.IndividualUsagePriceIDs.EUR,
4546
USD: expWebAppConfig.Stripe.IndividualUsagePriceIDs.USD,
4647
},
47-
TeamUsagePriceIDs: server.PriceConfig{
48+
TeamUsagePriceIDs: stripe.PriceConfig{
4849
EUR: expWebAppConfig.Stripe.TeamUsagePriceIDs.EUR,
4950
USD: expWebAppConfig.Stripe.TeamUsagePriceIDs.USD,
5051
},

0 commit comments

Comments
 (0)