Skip to content

Use auth v2 in subscription activation flow #6366

New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Merged
merged 5 commits into from
Jul 11, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
Original file line number Diff line number Diff line change
Expand Up @@ -182,8 +182,14 @@ interface PrivacyProFeature {
* This flag will be used to select FE subscription messaging mode.
* The value is added into GetFeatureConfig to allow FE to select the mode.
*/
@Toggle.DefaultValue(DefaultFeatureValue.FALSE)
@Toggle.DefaultValue(DefaultFeatureValue.INTERNAL)
fun enableSubscriptionFlowsV2(): Toggle

/**
* Kill-switch for in-memory caching of auth v2 JWKs.
*/
@Toggle.DefaultValue(DefaultFeatureValue.TRUE)
fun authApiV2JwksCache(): Toggle
}

@ContributesBinding(AppScope::class)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -210,6 +210,11 @@ interface SubscriptionsManager {
*/
suspend fun signInV1(authToken: String)

/**
* Signs the user in using the provided v2 access and refresh tokens
*/
suspend fun signInV2(accessToken: String, refreshToken: String)

/**
* Signs the user out and deletes all the data from the device
*/
Expand Down Expand Up @@ -382,6 +387,21 @@ class RealSubscriptionsManager @Inject constructor(
}
}

override suspend fun signInV2(
accessToken: String,
refreshToken: String,
) {
val tokens = TokenPair(accessToken, refreshToken)
val jwks = authClient.getJwks()
saveTokens(validateTokens(tokens, jwks))
authRepository.purchaseToWaitingStatus()
try {
refreshSubscriptionData()
} catch (e: Exception) {
logcat { "Subs: error when refreshing subscription on v2 sign in" }
}
}

override suspend fun signOut() {
authRepository.getAccessTokenV2()?.run {
coroutineScope.launch { authClient.tryLogout(accessTokenV2 = jwt) }
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -18,9 +18,17 @@ package com.duckduckgo.subscriptions.impl.auth2

import android.net.Uri
import com.duckduckgo.appbuildconfig.api.AppBuildConfig
import com.duckduckgo.common.utils.CurrentTimeProvider
import com.duckduckgo.common.utils.DispatcherProvider
import com.duckduckgo.di.scopes.AppScope
import com.duckduckgo.subscriptions.impl.PrivacyProFeature
import com.squareup.anvil.annotations.ContributesBinding
import dagger.Lazy
import dagger.SingleInstanceIn
import java.time.Duration
import java.time.Instant
import javax.inject.Inject
import kotlinx.coroutines.withContext
import logcat.logcat
import retrofit2.HttpException
import retrofit2.Response
Expand Down Expand Up @@ -112,11 +120,17 @@ data class TokenPair(
)

@ContributesBinding(AppScope::class)
@SingleInstanceIn(AppScope::class)
class AuthClientImpl @Inject constructor(
private val authService: AuthService,
private val appBuildConfig: AppBuildConfig,
private val timeProvider: CurrentTimeProvider,
private val privacyProFeature: Lazy<PrivacyProFeature>,
private val dispatchers: DispatcherProvider,
) : AuthClient {

private var cachedJwks: CachedJwks? = null

override suspend fun authorize(codeChallenge: String): String {
val response = authService.authorize(
responseType = AUTH_V2_RESPONSE_TYPE,
Expand Down Expand Up @@ -183,8 +197,20 @@ class AuthClientImpl @Inject constructor(
)
}

override suspend fun getJwks(): String =
authService.jwks().string()
override suspend fun getJwks(): String {
val useCache = withContext(dispatchers.io()) {
privacyProFeature.get().authApiV2JwksCache().isEnabled()
}

return if (useCache) {
val cachedResult = cachedJwks?.takeIf { it.timestamp + JWKS_CACHE_DURATION > getCurrentTime() }?.jwks

cachedResult ?: authService.jwks().string()
.also { cachedJwks = CachedJwks(jwks = it, timestamp = getCurrentTime()) }
} else {
authService.jwks().string()
}
}

override suspend fun storeLogin(
sessionId: String,
Expand Down Expand Up @@ -242,6 +268,13 @@ class AuthClientImpl @Inject constructor(
}
}

private fun getCurrentTime(): Instant = Instant.ofEpochMilli(timeProvider.currentTimeMillis())

private data class CachedJwks(
val jwks: String,
val timestamp: Instant,
)

private companion object {
const val AUTH_V2_CLIENT_ID = "f4311287-0121-40e6-8bbd-85c36daf1837"
const val AUTH_V2_REDIRECT_URI = "com.duckduckgo:/authcb"
Expand All @@ -250,5 +283,6 @@ class AuthClientImpl @Inject constructor(
const val AUTH_V2_RESPONSE_TYPE = "code"
const val GRANT_TYPE_AUTHORIZATION_CODE = "authorization_code"
const val GRANT_TYPE_REFRESH_TOKEN = "refresh_token"
val JWKS_CACHE_DURATION: Duration = Duration.ofHours(1)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -69,6 +69,7 @@ class SubscriptionMessagingInterface @Inject constructor(
SubscriptionsHandler(),
GetSubscriptionMessage(subscriptionsManager, dispatcherProvider),
SetSubscriptionMessage(subscriptionsManager, appCoroutineScope, dispatcherProvider, pixelSender, subscriptionsChecker),
SetAuthTokensMessage(subscriptionsManager, appCoroutineScope, dispatcherProvider, pixelSender, subscriptionsChecker),
InformationalEventsMessage(subscriptionsManager, appCoroutineScope, pixelSender),
GetAccessTokenMessage(subscriptionsManager),
GetAuthAccessTokenMessage(subscriptionsManager),
Expand Down Expand Up @@ -222,6 +223,43 @@ class SubscriptionMessagingInterface @Inject constructor(
override val methods: List<String> = listOf("setSubscription")
}

inner class SetAuthTokensMessage(
private val subscriptionsManager: SubscriptionsManager,
@AppCoroutineScope private val appCoroutineScope: CoroutineScope,
private val dispatcherProvider: DispatcherProvider,
private val pixelSender: SubscriptionPixelSender,
private val subscriptionsChecker: SubscriptionsChecker,
) : JsMessageHandler {

override fun process(
jsMessage: JsMessage,
secret: String,
jsMessageCallback: JsMessageCallback?,
) {
val (accessToken, refreshToken) = try {
with(jsMessage.params) { getString("accessToken") to getString("refreshToken") }
} catch (e: Exception) {
logcat { "Error parsing the tokens" }
return
}

appCoroutineScope.launch(dispatcherProvider.io()) {
try {
subscriptionsManager.signInV2(accessToken, refreshToken)
subscriptionsChecker.runChecker()
pixelSender.reportRestoreUsingEmailSuccess()
pixelSender.reportSubscriptionActivated()
} catch (e: Exception) {
logcat { "Failed to set auth tokens" }
}
}
}

override val allowedDomains: List<String> = emptyList()
override val featureName: String = "useSubscription"
override val methods: List<String> = listOf("setAuthTokens")
}

private class InformationalEventsMessage(
private val subscriptionsManager: SubscriptionsManager,
@AppCoroutineScope private val appCoroutineScope: CoroutineScope,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -19,13 +19,15 @@ package com.duckduckgo.subscriptions.impl.ui
import androidx.lifecycle.ViewModel
import androidx.lifecycle.viewModelScope
import com.duckduckgo.anvil.annotations.ContributesViewModel
import com.duckduckgo.app.di.AppCoroutineScope
import com.duckduckgo.common.utils.DispatcherProvider
import com.duckduckgo.di.scopes.ActivityScope
import com.duckduckgo.subscriptions.api.SubscriptionStatus
import com.duckduckgo.subscriptions.impl.RealSubscriptionsManager.Companion.SUBSCRIPTION_NOT_FOUND_ERROR
import com.duckduckgo.subscriptions.impl.RealSubscriptionsManager.RecoverSubscriptionResult
import com.duckduckgo.subscriptions.impl.SubscriptionsChecker
import com.duckduckgo.subscriptions.impl.SubscriptionsManager
import com.duckduckgo.subscriptions.impl.auth2.AuthClient
import com.duckduckgo.subscriptions.impl.pixels.SubscriptionPixelSender
import com.duckduckgo.subscriptions.impl.repository.isExpired
import com.duckduckgo.subscriptions.impl.ui.RestoreSubscriptionViewModel.Command.Error
Expand All @@ -35,6 +37,7 @@ import com.duckduckgo.subscriptions.impl.ui.RestoreSubscriptionViewModel.Command
import com.duckduckgo.subscriptions.impl.ui.RestoreSubscriptionViewModel.Command.SubscriptionNotFound
import com.duckduckgo.subscriptions.impl.ui.RestoreSubscriptionViewModel.Command.Success
import javax.inject.Inject
import kotlinx.coroutines.CoroutineScope
import kotlinx.coroutines.channels.BufferOverflow.DROP_OLDEST
import kotlinx.coroutines.channels.Channel
import kotlinx.coroutines.flow.Flow
Expand All @@ -44,13 +47,16 @@ import kotlinx.coroutines.flow.launchIn
import kotlinx.coroutines.flow.onEach
import kotlinx.coroutines.flow.receiveAsFlow
import kotlinx.coroutines.launch
import logcat.logcat

@ContributesViewModel(ActivityScope::class)
class RestoreSubscriptionViewModel @Inject constructor(
private val subscriptionsManager: SubscriptionsManager,
private val subscriptionsChecker: SubscriptionsChecker,
private val dispatcherProvider: DispatcherProvider,
private val pixelSender: SubscriptionPixelSender,
private val authClient: AuthClient,
@AppCoroutineScope private val appCoroutineScope: CoroutineScope,
) : ViewModel() {

private val command = Channel<Command>(1, DROP_OLDEST)
Expand Down Expand Up @@ -106,6 +112,7 @@ class RestoreSubscriptionViewModel @Inject constructor(
viewModelScope.launch {
command.send(RestoreFromEmail)
}
warmUpJwksCache()
}

fun onSubscriptionRestoredFromEmail() = viewModelScope.launch {
Expand All @@ -116,6 +123,20 @@ class RestoreSubscriptionViewModel @Inject constructor(
}
}

/*
We'll need JWKs to validate auth tokens returned by FE after the user completes activation flow using email.
Prefetching them is optional, but it reduces the risk of failure when the network connection is unstable.
*/
private fun warmUpJwksCache() {
appCoroutineScope.launch {
try {
authClient.getJwks()
} catch (e: Exception) {
logcat { "Failed to warm-up JWKs cache, e: ${e.stackTraceToString()}" }
}
}
}

sealed class Command {
data object RestoreFromEmail : Command()
data object Success : Command()
Expand Down
Original file line number Diff line number Diff line change
@@ -1,15 +1,26 @@
package com.duckduckgo.subscriptions.impl.auth2

import android.annotation.SuppressLint
import androidx.test.ext.junit.runners.AndroidJUnit4
import com.duckduckgo.appbuildconfig.api.AppBuildConfig
import com.duckduckgo.common.test.CoroutineTestRule
import com.duckduckgo.common.utils.CurrentTimeProvider
import com.duckduckgo.feature.toggles.api.FakeFeatureToggleFactory
import com.duckduckgo.feature.toggles.api.Toggle.State
import com.duckduckgo.subscriptions.impl.PrivacyProFeature
import java.time.Duration
import java.time.Instant
import java.time.LocalDateTime
import kotlinx.coroutines.test.runTest
import okhttp3.Headers
import okhttp3.MediaType.Companion.toMediaTypeOrNull
import okhttp3.ResponseBody.Companion.toResponseBody
import org.junit.Assert.assertEquals
import org.junit.Assert.fail
import org.junit.Rule
import org.junit.Test
import org.junit.runner.RunWith
import org.mockito.Mockito.times
import org.mockito.kotlin.any
import org.mockito.kotlin.anyOrNull
import org.mockito.kotlin.doReturn
Expand All @@ -22,11 +33,23 @@ import retrofit2.Response
@RunWith(AndroidJUnit4::class)
class AuthClientImplTest {

@get:Rule
var coroutinesTestRule = CoroutineTestRule()

private val authService: AuthService = mock()
private val appBuildConfig: AppBuildConfig = mock { config ->
whenever(config.applicationId).thenReturn("com.duckduckgo.android")
}
private val authClient = AuthClientImpl(authService, appBuildConfig)
private val timeProvider = FakeTimeProvider()
private val privacyProFeature = FakeFeatureToggleFactory.create(PrivacyProFeature::class.java)

private val authClient = AuthClientImpl(
authService = authService,
appBuildConfig = appBuildConfig,
timeProvider = timeProvider,
privacyProFeature = { privacyProFeature },
dispatchers = coroutinesTestRule.testDispatcherProvider,
)

@Test
fun `when authorize success then returns sessionId parsed from Set-Cookie header`() = runTest {
Expand Down Expand Up @@ -264,4 +287,90 @@ class AuthClientImplTest {

authClient.tryLogout("fake v2 access token")
}

@Test
fun `when JWKS not cached then fetches from network`() = runTest {
val jwksJson = """{"keys": [{"kty": "RSA", "kid": "networkKey"}]}"""
val responseBody = jwksJson.toResponseBody("application/json".toMediaTypeOrNull())

whenever(authService.jwks()).thenReturn(responseBody)

val result = authClient.getJwks()

assertEquals(jwksJson, result)
verify(authService).jwks()
}

@Test
fun `when JWKS is cached and not expired then returns cached value`() = runTest {
val jwksJson = """{"keys": [{"kty": "RSA", "kid": "cachedKey"}]}"""
val responseBody = jwksJson.toResponseBody("application/json".toMediaTypeOrNull())

whenever(authService.jwks()).thenReturn(responseBody)

// Initial request
val first = authClient.getJwks()
assertEquals(jwksJson, first)

// Advance time just before expiration
timeProvider.currentTime += Duration.ofMinutes(59)

val second = authClient.getJwks()
assertEquals(jwksJson, second)

// Verify network call happened only once
verify(authService).jwks()
}

@Test
fun `when JWKS cache is expired then fetches new value`() = runTest {
val oldJwks = """{"keys": [{"kty": "RSA", "kid": "oldKey"}]}"""
val newJwks = """{"keys": [{"kty": "RSA", "kid": "newKey"}]}"""

whenever(authService.jwks())
.thenReturn(oldJwks.toResponseBody("application/json".toMediaTypeOrNull()))
.thenReturn(newJwks.toResponseBody("application/json".toMediaTypeOrNull()))

// Initial call → old value cached
val first = authClient.getJwks()
assertEquals(oldJwks, first)

// Advance time past expiration
timeProvider.currentTime += Duration.ofMinutes(61)

// Call again → should return new JWKS
val second = authClient.getJwks()
assertEquals(newJwks, second)

verify(authService, times(2)).jwks()
}

@SuppressLint("DenyListedApi")
@Test
fun `when JWKS cache is disabled then always fetches from network`() = runTest {
privacyProFeature.authApiV2JwksCache().setRawStoredState(State(false))

val jwks1 = """{"keys": [{"kty": "RSA", "kid": "key1"}]}"""
val jwks2 = """{"keys": [{"kty": "RSA", "kid": "key2"}]}"""

whenever(authService.jwks())
.thenReturn(jwks1.toResponseBody("application/json".toMediaTypeOrNull()))
.thenReturn(jwks2.toResponseBody("application/json".toMediaTypeOrNull()))

val first = authClient.getJwks()
val second = authClient.getJwks()

assertEquals(jwks1, first)
assertEquals(jwks2, second)

verify(authService, times(2)).jwks()
}

private class FakeTimeProvider : CurrentTimeProvider {
var currentTime: Instant = Instant.parse("2024-10-28T00:00:00Z")

override fun elapsedRealtime(): Long = throw UnsupportedOperationException()
override fun currentTimeMillis(): Long = currentTime.toEpochMilli()
override fun localDateTimeNow(): LocalDateTime = throw UnsupportedOperationException()
}
}
Loading
Loading