Skip to content

Commit b7a62fe

Browse files
author
Laurie T. Malau
committed
[usage] Implement CreateStripeSubscription
1 parent 44de257 commit b7a62fe

File tree

8 files changed

+280
-96
lines changed

8 files changed

+280
-96
lines changed

components/usage-api/go/v1/billing.pb.go

Lines changed: 83 additions & 73 deletions
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

components/usage-api/typescript/src/usage/v1/billing.pb.ts

Lines changed: 14 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -59,6 +59,7 @@ export interface CreateStripeCustomerResponse {
5959

6060
export interface CreateStripeSubscriptionRequest {
6161
attributionId: string;
62+
customerId: string;
6263
setupIntentId: string;
6364
usageLimit: number;
6465
}
@@ -623,19 +624,22 @@ export const CreateStripeCustomerResponse = {
623624
};
624625

625626
function createBaseCreateStripeSubscriptionRequest(): CreateStripeSubscriptionRequest {
626-
return { attributionId: "", setupIntentId: "", usageLimit: 0 };
627+
return { attributionId: "", customerId: "", setupIntentId: "", usageLimit: 0 };
627628
}
628629

629630
export const CreateStripeSubscriptionRequest = {
630631
encode(message: CreateStripeSubscriptionRequest, writer: _m0.Writer = _m0.Writer.create()): _m0.Writer {
631632
if (message.attributionId !== "") {
632633
writer.uint32(10).string(message.attributionId);
633634
}
635+
if (message.customerId !== "") {
636+
writer.uint32(18).string(message.customerId);
637+
}
634638
if (message.setupIntentId !== "") {
635-
writer.uint32(18).string(message.setupIntentId);
639+
writer.uint32(26).string(message.setupIntentId);
636640
}
637641
if (message.usageLimit !== 0) {
638-
writer.uint32(24).int64(message.usageLimit);
642+
writer.uint32(32).int64(message.usageLimit);
639643
}
640644
return writer;
641645
},
@@ -651,9 +655,12 @@ export const CreateStripeSubscriptionRequest = {
651655
message.attributionId = reader.string();
652656
break;
653657
case 2:
654-
message.setupIntentId = reader.string();
658+
message.customerId = reader.string();
655659
break;
656660
case 3:
661+
message.setupIntentId = reader.string();
662+
break;
663+
case 4:
657664
message.usageLimit = longToNumber(reader.int64() as Long);
658665
break;
659666
default:
@@ -667,6 +674,7 @@ export const CreateStripeSubscriptionRequest = {
667674
fromJSON(object: any): CreateStripeSubscriptionRequest {
668675
return {
669676
attributionId: isSet(object.attributionId) ? String(object.attributionId) : "",
677+
customerId: isSet(object.customerId) ? String(object.customerId) : "",
670678
setupIntentId: isSet(object.setupIntentId) ? String(object.setupIntentId) : "",
671679
usageLimit: isSet(object.usageLimit) ? Number(object.usageLimit) : 0,
672680
};
@@ -675,6 +683,7 @@ export const CreateStripeSubscriptionRequest = {
675683
toJSON(message: CreateStripeSubscriptionRequest): unknown {
676684
const obj: any = {};
677685
message.attributionId !== undefined && (obj.attributionId = message.attributionId);
686+
message.customerId !== undefined && (obj.customerId = message.customerId);
678687
message.setupIntentId !== undefined && (obj.setupIntentId = message.setupIntentId);
679688
message.usageLimit !== undefined && (obj.usageLimit = Math.round(message.usageLimit));
680689
return obj;
@@ -683,6 +692,7 @@ export const CreateStripeSubscriptionRequest = {
683692
fromPartial(object: DeepPartial<CreateStripeSubscriptionRequest>): CreateStripeSubscriptionRequest {
684693
const message = createBaseCreateStripeSubscriptionRequest();
685694
message.attributionId = object.attributionId ?? "";
695+
message.customerId = object.customerId ?? "";
686696
message.setupIntentId = object.setupIntentId ?? "";
687697
message.usageLimit = object.usageLimit ?? 0;
688698
return message;

components/usage-api/usage/v1/billing.proto

Lines changed: 3 additions & 2 deletions
Original file line numberDiff line numberDiff line change
@@ -78,8 +78,9 @@ message CreateStripeCustomerResponse {
7878

7979
message CreateStripeSubscriptionRequest {
8080
string attribution_id = 1;
81-
string setup_intent_id = 2;
82-
int64 usage_limit = 3;
81+
string customer_id = 2;
82+
string setup_intent_id = 3;
83+
int64 usage_limit = 4;
8384
}
8485

8586
message CreateStripeSubscriptionResponse {

components/usage/pkg/apiv1/billing.go

Lines changed: 73 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -23,18 +23,20 @@ import (
2323
"gorm.io/gorm"
2424
)
2525

26-
func NewBillingService(stripeClient *stripe.Client, conn *gorm.DB, ccManager *db.CostCenterManager) *BillingService {
26+
func NewBillingService(stripeClient *stripe.Client, conn *gorm.DB, ccManager *db.CostCenterManager, stripePrices stripe.StripePrices) *BillingService {
2727
return &BillingService{
2828
stripeClient: stripeClient,
2929
conn: conn,
3030
ccManager: ccManager,
31+
stripePrices: stripePrices,
3132
}
3233
}
3334

3435
type BillingService struct {
3536
conn *gorm.DB
3637
stripeClient *stripe.Client
3738
ccManager *db.CostCenterManager
39+
stripePrices stripe.StripePrices
3840

3941
v1.UnimplementedBillingServiceServer
4042
}
@@ -165,6 +167,76 @@ func (s *BillingService) CreateStripeCustomer(ctx context.Context, req *v1.Creat
165167
}, nil
166168
}
167169

170+
func (s *BillingService) CreateStripeSubscription(ctx context.Context, req *v1.CreateStripeSubscriptionRequest) (*v1.CreateStripeSubscriptionResponse, error) {
171+
attributionID, err := db.ParseAttributionID(req.GetAttributionId())
172+
if err != nil {
173+
return nil, status.Errorf(codes.InvalidArgument, "Invalid attribution ID %s", attributionID)
174+
}
175+
176+
customer, err := s.GetStripeCustomer(ctx, &v1.GetStripeCustomerRequest{Identifier: &v1.GetStripeCustomerRequest_AttributionId{AttributionId: string(attributionID)}})
177+
if err != nil {
178+
return nil, status.Errorf(codes.NotFound, "Stripe customer with attribution ID %s not found", attributionID)
179+
}
180+
181+
_, err = s.stripeClient.SetDefaultPaymentForCustomer(ctx, customer.Customer.Id, req.SetupIntentId)
182+
if err != nil {
183+
return nil, status.Errorf(codes.InvalidArgument, "Failed to set default payment for customer ID %s", customer.Customer.Id)
184+
}
185+
186+
stripeCustomer, err := s.stripeClient.GetCustomer(ctx, customer.Customer.Id)
187+
if err != nil {
188+
return nil, err
189+
}
190+
191+
priceIdentifier := getPriceIdentifier(attributionID, stripeCustomer, s)
192+
if priceIdentifier == "" {
193+
return nil, status.Errorf(codes.InvalidArgument, "Invalid currency %s for customer ID %s", stripeCustomer.Metadata["preferredCurrency"], stripeCustomer.ID)
194+
}
195+
196+
var isAutomaticTaxSupported bool
197+
if stripeCustomer.Tax != nil {
198+
isAutomaticTaxSupported = stripeCustomer.Tax.AutomaticTax == "supported"
199+
}
200+
if !isAutomaticTaxSupported {
201+
log.Warnf("Automatic Stripe tax is not supported for customer %s", stripeCustomer.ID)
202+
}
203+
204+
subscription, err := s.stripeClient.CreateSubscription(ctx, stripeCustomer.ID, priceIdentifier, isAutomaticTaxSupported)
205+
if err != nil {
206+
return nil, status.Errorf(codes.Aborted, "Failed to create subscription with customer ID %s", customer.Customer.Id)
207+
}
208+
209+
return &v1.CreateStripeSubscriptionResponse{
210+
Subscription: &v1.StripeSubscription{
211+
Id: subscription.ID,
212+
},
213+
}, nil
214+
}
215+
216+
func getPriceIdentifier(attributionID db.AttributionID, stripeCustomer *stripe_api.Customer, s *BillingService) string {
217+
priceIdentifier := ""
218+
219+
if stripeCustomer.Metadata["preferredCurrency"] == "" {
220+
log.WithField("stripe_customer_id", stripeCustomer.ID).WithField("stripe_preferred_currency", stripeCustomer.Metadata["preferredCurrency"]).Warn("No preferred currency set. Defaulting to X")
221+
}
222+
223+
if attributionID.IsEntity("team") {
224+
if stripeCustomer.Metadata["preferredCurrency"] == "EUR" {
225+
priceIdentifier = s.stripePrices.TeamUsagePriceIDs.EUR
226+
} else {
227+
priceIdentifier = s.stripePrices.TeamUsagePriceIDs.USD
228+
}
229+
}
230+
if attributionID.IsEntity("user") {
231+
if stripeCustomer.Metadata["preferredCurrency"] == "EUR" {
232+
priceIdentifier = s.stripePrices.IndividualUsagePriceIDs.EUR
233+
} else {
234+
priceIdentifier = s.stripePrices.IndividualUsagePriceIDs.USD
235+
}
236+
}
237+
return priceIdentifier
238+
}
239+
168240
func (s *BillingService) ReconcileInvoices(ctx context.Context, in *v1.ReconcileInvoicesRequest) (*v1.ReconcileInvoicesResponse, error) {
169241
balances, err := db.ListBalance(ctx, s.conn)
170242
if err != nil {

components/usage/pkg/server/server.go

Lines changed: 2 additions & 12 deletions
Original file line numberDiff line numberDiff line change
@@ -44,17 +44,7 @@ type Config struct {
4444
DefaultSpendingLimit db.DefaultSpendingLimit `json:"defaultSpendingLimit"`
4545

4646
// StripePrices configure which Stripe Price IDs should be used
47-
StripePrices StripePrices `json:"stripePrices"`
48-
}
49-
50-
type PriceConfig struct {
51-
EUR string `json:"eur"`
52-
USD string `json:"usd"`
53-
}
54-
55-
type StripePrices struct {
56-
IndividualUsagePriceIDs PriceConfig `json:"individualUsagePriceIds"`
57-
TeamUsagePriceIDs PriceConfig `json:"teamUsagePriceIds"`
47+
StripePrices stripe.StripePrices `json:"stripePrices"`
5848
}
5949

6050
func Start(cfg Config, version string) error {
@@ -188,7 +178,7 @@ func registerGRPCServices(srv *baseserver.Server, conn *gorm.DB, stripeClient *s
188178
if stripeClient == nil {
189179
v1.RegisterBillingServiceServer(srv.GRPC(), &apiv1.BillingServiceNoop{})
190180
} else {
191-
v1.RegisterBillingServiceServer(srv.GRPC(), apiv1.NewBillingService(stripeClient, conn, ccManager))
181+
v1.RegisterBillingServiceServer(srv.GRPC(), apiv1.NewBillingService(stripeClient, conn, ccManager, cfg.StripePrices))
192182
}
193183
return nil
194184
}

components/usage/pkg/stripe/stripe.go

Lines changed: 86 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -37,6 +37,16 @@ type ClientConfig struct {
3737
SecretKey string `json:"secretKey"`
3838
}
3939

40+
type PriceConfig struct {
41+
EUR string `json:"eur"`
42+
USD string `json:"usd"`
43+
}
44+
45+
type StripePrices struct {
46+
IndividualUsagePriceIDs PriceConfig `json:"individualUsagePriceIds"`
47+
TeamUsagePriceIDs PriceConfig `json:"teamUsagePriceIds"`
48+
}
49+
4050
func ReadConfigFromFile(path string) (ClientConfig, error) {
4151
bytes, err := os.ReadFile(path)
4252
if err != nil {
@@ -320,6 +330,82 @@ func (c *Client) GetSubscriptionWithCustomer(ctx context.Context, subscriptionID
320330
return subscription, nil
321331
}
322332

333+
func (c *Client) CreateSubscription(ctx context.Context, customerID string, priceID string, isAutomaticTaxSupported bool) (*stripe.Subscription, error) {
334+
if customerID == "" {
335+
return nil, fmt.Errorf("no customerID specified")
336+
}
337+
if priceID == "" {
338+
return nil, fmt.Errorf("no priceID specified")
339+
}
340+
341+
startOfNextMonth := getStartOfNextMonth(time.Now())
342+
343+
params := &stripe.SubscriptionParams{
344+
Customer: &customerID,
345+
Items: []*stripe.SubscriptionItemsParams{
346+
{
347+
Price: &priceID,
348+
},
349+
},
350+
AutomaticTax: &stripe.SubscriptionAutomaticTaxParams{
351+
Enabled: &isAutomaticTaxSupported,
352+
},
353+
BillingCycleAnchor: stripe.Int64(startOfNextMonth.Unix()),
354+
}
355+
356+
subscription, err := c.sc.Subscriptions.New(params)
357+
if err != nil {
358+
return nil, fmt.Errorf("failed to get subscription with customer ID %s", customerID)
359+
}
360+
361+
return subscription, err
362+
}
363+
364+
func getStartOfNextMonth(t time.Time) time.Time {
365+
currentYear, currentMonth, _ := t.Date()
366+
367+
firstOfMonth := time.Date(currentYear, currentMonth, 1, 0, 0, 0, 0, time.UTC)
368+
startOfNextMonth := firstOfMonth.AddDate(0, 1, 0)
369+
370+
return startOfNextMonth
371+
}
372+
373+
func (c *Client) SetDefaultPaymentForCustomer(ctx context.Context, customerID string, setupIntentId string) (*stripe.Customer, error) {
374+
if customerID == "" {
375+
return nil, fmt.Errorf("no customerID specified")
376+
}
377+
378+
if setupIntentId == "" {
379+
return nil, fmt.Errorf("no setupIntentID specified")
380+
}
381+
382+
setupIntent, err := c.sc.SetupIntents.Get(setupIntentId, &stripe.SetupIntentParams{
383+
Params: stripe.Params{
384+
Context: ctx,
385+
},
386+
})
387+
if err != nil {
388+
return nil, fmt.Errorf("Failed to retrieve setup intent with id %s", setupIntentId)
389+
}
390+
391+
paymentMethod, err := c.sc.PaymentMethods.Attach(setupIntent.PaymentMethod.ID, &stripe.PaymentMethodAttachParams{Customer: &customerID})
392+
if err != nil {
393+
return nil, fmt.Errorf("Failed to attach payment method to setup intent ID %s", setupIntentId)
394+
}
395+
396+
customer, _ := c.sc.Customers.Update(customerID, &stripe.CustomerParams{
397+
InvoiceSettings: &stripe.CustomerInvoiceSettingsParams{
398+
DefaultPaymentMethod: &paymentMethod.ID},
399+
Address: &stripe.AddressParams{
400+
Line1: &paymentMethod.BillingDetails.Address.Line1,
401+
Country: &paymentMethod.BillingDetails.Address.Country}})
402+
if err != nil {
403+
return nil, fmt.Errorf("Failed to update customer with id %s", customerID)
404+
}
405+
406+
return customer, nil
407+
}
408+
323409
func GetAttributionID(ctx context.Context, customer *stripe.Customer) (db.AttributionID, error) {
324410
if customer == nil {
325411
log.Error("No customer information available for invoice.")

components/usage/pkg/stripe/stripe_test.go

Lines changed: 15 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -6,8 +6,10 @@ package stripe
66

77
import (
88
"fmt"
9-
"github.com/gitpod-io/gitpod/usage/pkg/db"
109
"testing"
10+
"time"
11+
12+
"github.com/gitpod-io/gitpod/usage/pkg/db"
1113

1214
"github.com/stretchr/testify/require"
1315
)
@@ -89,3 +91,15 @@ func TestCustomerQueriesForTeamIds_MultipleQueries(t *testing.T) {
8991
})
9092
}
9193
}
94+
95+
func TestStartOfNextMonth(t *testing.T) {
96+
now := time.Now()
97+
currentYear, currentMonth, _ := now.Date()
98+
firstOfMonth := time.Date(currentYear, currentMonth, 1, 0, 0, 0, 0, time.UTC)
99+
nextMonth := firstOfMonth.AddDate(0, 1, 0).Month()
100+
101+
expectedStartOfNextMonth := time.Time(time.Date(currentYear, nextMonth, 1, 0, 0, 0, 0, time.UTC))
102+
actualStartOfNextMonth := getStartOfNextMonth(now)
103+
104+
require.Equal(t, actualStartOfNextMonth, expectedStartOfNextMonth)
105+
}

install/installer/pkg/components/usage/configmap.go

Lines changed: 4 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -10,6 +10,7 @@ import (
1010
"github.com/gitpod-io/gitpod/common-go/baseserver"
1111
"github.com/gitpod-io/gitpod/usage/pkg/db"
1212
"github.com/gitpod-io/gitpod/usage/pkg/server"
13+
"github.com/gitpod-io/gitpod/usage/pkg/stripe"
1314

1415
"github.com/gitpod-io/gitpod/installer/pkg/common"
1516
"github.com/gitpod-io/gitpod/installer/pkg/config/v1/experimental"
@@ -39,12 +40,12 @@ func configmap(ctx *common.RenderContext) ([]runtime.Object, error) {
3940

4041
expWebAppConfig := getExperimentalWebAppConfig(ctx)
4142
if expWebAppConfig != nil && expWebAppConfig.Stripe != nil {
42-
cfg.StripePrices = server.StripePrices{
43-
IndividualUsagePriceIDs: server.PriceConfig{
43+
cfg.StripePrices = stripe.StripePrices{
44+
IndividualUsagePriceIDs: stripe.PriceConfig{
4445
EUR: expWebAppConfig.Stripe.IndividualUsagePriceIDs.EUR,
4546
USD: expWebAppConfig.Stripe.IndividualUsagePriceIDs.USD,
4647
},
47-
TeamUsagePriceIDs: server.PriceConfig{
48+
TeamUsagePriceIDs: stripe.PriceConfig{
4849
EUR: expWebAppConfig.Stripe.TeamUsagePriceIDs.EUR,
4950
USD: expWebAppConfig.Stripe.TeamUsagePriceIDs.USD,
5051
},

0 commit comments

Comments
 (0)