diff --git a/billing/subscription/service.go b/billing/subscription/service.go index 8fb0ef95d..fac8aecc4 100644 --- a/billing/subscription/service.go +++ b/billing/subscription/service.go @@ -972,7 +972,7 @@ func (s *Service) getPlanFromSchedule(ctx context.Context, stripeSchedule *strip // CancelUpcomingPhase cancels the scheduled phase of the subscription func (s *Service) CancelUpcomingPhase(ctx context.Context, sub Subscription) error { - _, stripeSchedule, err := s.createOrGetSchedule(ctx, sub) + stripeSub, stripeSchedule, err := s.createOrGetSchedule(ctx, sub) if err != nil { return err } @@ -988,6 +988,12 @@ func (s *Service) CancelUpcomingPhase(ctx context.Context, sub Subscription) err var currency = string(stripeSchedule.Phases[0].Currency) var prorationBehavior = s.config.PlanChangeConfig.ProrationBehavior + var endBehavior = stripe.SubscriptionScheduleEndBehaviorRelease + + if stripeSub.Status == stripe.SubscriptionStatusTrialing { + endBehavior = stripe.SubscriptionScheduleEndBehaviorCancel + } + // update the phases _, err = s.stripeClient.SubscriptionSchedules.Update(stripeSchedule.ID, &stripe.SubscriptionScheduleParams{ Params: stripe.Params{ @@ -1005,7 +1011,7 @@ func (s *Service) CancelUpcomingPhase(ctx context.Context, sub Subscription) err }, }, }, - EndBehavior: stripe.String("release"), + EndBehavior: stripe.String(string(endBehavior)), ProrationBehavior: stripe.String(prorationBehavior), DefaultSettings: &stripe.SubscriptionScheduleDefaultSettingsParams{ CollectionMethod: stripe.String(s.config.PlanChangeConfig.CollectionMethod),