Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Subscription keychain sharing for access token #690

Merged
merged 15 commits into from
Mar 6, 2024
Merged
42 changes: 37 additions & 5 deletions Sources/Subscription/AccountManager.swift
Original file line number Diff line number Diff line change
Expand Up @@ -37,14 +37,23 @@ public protocol AccountManaging {
public class AccountManager: AccountManaging {

private let storage: AccountStorage
private let accessTokenStorage: SubscriptionTokenStorage

public weak var delegate: AccountManagerKeychainAccessDelegate?

public var isUserAuthenticated: Bool {
return accessToken != nil
}

public init(storage: AccountStorage = AccountKeychainStorage()) {
public convenience init(appGroup: String) {
let accessTokenStorage = SubscriptionTokenKeychainStorage(keychainType: .dataProtection(.named(appGroup)))
self.init(accessTokenStorage: accessTokenStorage)
}

public init(storage: AccountStorage = AccountKeychainStorage(),
accessTokenStorage: SubscriptionTokenStorage) {
self.storage = storage
self.accessTokenStorage = accessTokenStorage
}

public var authToken: String? {
Expand All @@ -63,7 +72,7 @@ public class AccountManager: AccountManaging {

public var accessToken: String? {
do {
return try storage.getAccessToken()
return try accessTokenStorage.getAccessToken()
} catch {
if let error = error as? AccountKeychainAccessError {
delegate?.accountManagerKeychainAccessFailed(accessType: .getAccessToken, error: error)
Expand Down Expand Up @@ -121,7 +130,7 @@ public class AccountManager: AccountManaging {
os_log(.info, log: .subscription, "[AccountManager] storeAccount")

do {
try storage.store(accessToken: token)
try accessTokenStorage.store(accessToken: token)
} catch {
if let error = error as? AccountKeychainAccessError {
delegate?.accountManagerKeychainAccessFailed(accessType: .storeAccessToken, error: error)
Expand Down Expand Up @@ -157,6 +166,7 @@ public class AccountManager: AccountManaging {

do {
try storage.clearAuthenticationState()
try accessTokenStorage.removeAccessToken()
} catch {
if let error = error as? AccountKeychainAccessError {
delegate?.accountManagerKeychainAccessFailed(accessType: .clearAuthenticationData, error: error)
Expand All @@ -168,6 +178,28 @@ public class AccountManager: AccountManaging {
NotificationCenter.default.post(name: .accountDidSignOut, object: self, userInfo: nil)
}

public func migrateAccessTokenToNewStore() throws {
var errorToThrow: Error?
do {
if let newAccessToken = try accessTokenStorage.getAccessToken() {
errorToThrow = MigrationError.noMigrationNeeded
} else if let oldAccessToken = try storage.getAccessToken() {
try accessTokenStorage.store(accessToken: oldAccessToken)
}
} catch {
errorToThrow = MigrationError.migrationFailed
}

if let errorToThrow {
throw errorToThrow
}
}

public enum MigrationError: Error {
case migrationFailed
case noMigrationNeeded
}

// MARK: -

public enum Entitlement: String {
Expand Down Expand Up @@ -238,12 +270,12 @@ public class AccountManager: AccountManaging {
}

@discardableResult
public static func checkForEntitlements(wait waitTime: Double, retry retryCount: Int) async -> Bool {
public static func checkForEntitlements(subscriptionAppGroup: String, wait waitTime: Double, retry retryCount: Int) async -> Bool {
var count = 0
var hasEntitlements = false

repeat {
switch await AccountManager().fetchEntitlements() {
switch await AccountManager(appGroup: subscriptionAppGroup).fetchEntitlements() {
case .success(let entitlements):
hasEntitlements = !entitlements.isEmpty
case .failure:
Expand Down
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
//
// SubscriptionTokenKeychainStorage.swift
//
// Copyright © 2024 DuckDuckGo. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

import Foundation

public class SubscriptionTokenKeychainStorage: SubscriptionTokenStorage {

private let keychainType: KeychainType

public init(keychainType: KeychainType = .dataProtection(.unspecified)) {
self.keychainType = keychainType
}

public func getAccessToken() throws -> String? {
try getString(forField: .accessToken)
}

public func store(accessToken: String) throws {
try set(string: accessToken, forField: .accessToken)
}

public func removeAccessToken() throws {
try deleteItem(forField: .accessToken)
}
}

private extension SubscriptionTokenKeychainStorage {

/*
Uses just kSecAttrService as the primary key, since we don't want to store
multiple accounts/tokens at the same time
*/
enum AccountKeychainField: String, CaseIterable {
case accessToken = "subscription.account.accessToken"
case testString = "subscription.account.testString"
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

This is no longer needed ;)

Copy link
Contributor Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Oops!


var keyValue: String {
"com.duckduckgo" + "." + rawValue
}
}

func getString(forField field: AccountKeychainField) throws -> String? {
guard let data = try retrieveData(forField: field) else {
return nil
}

if let decodedString = String(data: data, encoding: String.Encoding.utf8) {
return decodedString
} else {
throw AccountKeychainAccessError.failedToDecodeKeychainDataAsString
}
}
func retrieveData(forField field: AccountKeychainField) throws -> Data? {
var query = defaultAttributes()
query[kSecAttrService] = field.keyValue
query[kSecMatchLimit] = kSecMatchLimitOne
query[kSecReturnData] = true

var item: CFTypeRef?
let status = SecItemCopyMatching(query as CFDictionary, &item)

if status == errSecSuccess {
if let existingItem = item as? Data {
return existingItem
} else {
throw AccountKeychainAccessError.failedToDecodeKeychainValueAsData
}
} else if status == errSecItemNotFound {
return nil
} else {
throw AccountKeychainAccessError.keychainLookupFailure(status)
}
}

func set(string: String, forField field: AccountKeychainField) throws {
guard let stringData = string.data(using: .utf8) else {
return
}

try store(data: stringData, forField: field)
}

func store(data: Data, forField field: AccountKeychainField) throws {
var query = defaultAttributes()
query[kSecAttrService] = field.keyValue
query[kSecAttrAccessible] = kSecAttrAccessibleAfterFirstUnlock
query[kSecValueData] = data

let status = SecItemAdd(query as CFDictionary, nil)

switch status {
case errSecSuccess:
return
case errSecDuplicateItem:
let updateStatus = updateData(data, forField: field)

if updateStatus != errSecSuccess {
throw AccountKeychainAccessError.keychainSaveFailure(status)
}
default:
throw AccountKeychainAccessError.keychainSaveFailure(status)
}
}

private func updateData(_ data: Data, forField field: AccountKeychainField) -> OSStatus {
var query = defaultAttributes()
query[kSecAttrService] = field.keyValue

let newAttributes = [
kSecValueData: data,
kSecAttrAccessible: kSecAttrAccessibleAfterFirstUnlock
] as [CFString: Any]

return SecItemUpdate(query as CFDictionary, newAttributes as CFDictionary)
}

func deleteItem(forField field: AccountKeychainField, useDataProtectionKeychain: Bool = true) throws {
let query = defaultAttributes()

let status = SecItemDelete(query as CFDictionary)

if status != errSecSuccess && status != errSecItemNotFound {
throw AccountKeychainAccessError.keychainDeleteFailure(status)
}
}

private func defaultAttributes() -> [CFString: Any] {
var attributes: [CFString: Any] = [
kSecClass: kSecClassGenericPassword,
kSecAttrSynchronizable: false
]

attributes.merge(keychainType.queryAttributes()) { $1 }

return attributes
}
}

public enum KeychainType {
case dataProtection(_ accessGroup: AccessGroup)

/// Uses the system keychain.
///
case system

case fileBased

public enum AccessGroup {
case unspecified
case named(_ name: String)
}

func queryAttributes() -> [CFString: Any] {
switch self {
case .dataProtection(let accessGroup):
switch accessGroup {
case .unspecified:
return [kSecUseDataProtectionKeychain: true]
case .named(let accessGroup):
return [
kSecUseDataProtectionKeychain: true,
kSecAttrAccessGroup: accessGroup
]
}
case .system:
return [kSecUseDataProtectionKeychain: false]
case .fileBased:
return [kSecUseDataProtectionKeychain: false]
}
}
}
25 changes: 25 additions & 0 deletions Sources/Subscription/AccountStorage/SubscriptionTokenStorage.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,25 @@
//
// SubscriptionTokenStorage.swift
//
// Copyright © 2024 DuckDuckGo. All rights reserved.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.
//

import Foundation

public protocol SubscriptionTokenStorage: AnyObject {
func getAccessToken() throws -> String?
func store(accessToken: String) throws
func removeAccessToken() throws
}
Original file line number Diff line number Diff line change
Expand Up @@ -29,10 +29,11 @@ public final class AppStoreAccountManagementFlow {
}

@discardableResult
public static func refreshAuthTokenIfNeeded() async -> Result<String, AppStoreAccountManagementFlow.Error> {
public static func refreshAuthTokenIfNeeded(subscriptionAppGroup: String) async -> Result<String, AppStoreAccountManagementFlow.Error> {
os_log(.info, log: .subscription, "[AppStoreAccountManagementFlow] refreshAuthTokenIfNeeded")
let accountManager = AccountManager(appGroup: subscriptionAppGroup)

var authToken = AccountManager().authToken ?? ""
var authToken = accountManager.authToken ?? ""

// Check if auth token if still valid
if case let .failure(validateTokenError) = await AuthService.validateToken(accessToken: authToken) {
Expand All @@ -43,9 +44,9 @@ public final class AppStoreAccountManagementFlow {

switch await AuthService.storeLogin(signature: lastTransactionJWSRepresentation) {
case .success(let response):
if response.externalID == AccountManager().externalID {
if response.externalID == accountManager.externalID {
authToken = response.authToken
AccountManager().storeAuthToken(token: authToken)
accountManager.storeAuthToken(token: authToken)
}
case .failure(let storeLoginError):
os_log(.error, log: .subscription, "[AppStoreAccountManagementFlow] storeLogin error: %{public}s", String(reflecting: storeLoginError))
Expand Down
12 changes: 6 additions & 6 deletions Sources/Subscription/Flows/AppStore/AppStorePurchaseFlow.swift
Original file line number Diff line number Diff line change
Expand Up @@ -57,14 +57,14 @@ public final class AppStorePurchaseFlow {
}

// swiftlint:disable cyclomatic_complexity
public static func purchaseSubscription(with subscriptionIdentifier: String, emailAccessToken: String?) async -> Result<Void, AppStorePurchaseFlow.Error> {
public static func purchaseSubscription(with subscriptionIdentifier: String, emailAccessToken: String?, subscriptionAppGroup: String) async -> Result<Void, AppStorePurchaseFlow.Error> {
os_log(.info, log: .subscription, "[AppStorePurchaseFlow] purchaseSubscription")

let accountManager = AccountManager()
let accountManager = AccountManager(appGroup: subscriptionAppGroup)
let externalID: String

// Check for past transactions most recent
switch await AppStoreRestoreFlow.restoreAccountFromPastPurchase() {
switch await AppStoreRestoreFlow.restoreAccountFromPastPurchase(appGroup: subscriptionAppGroup) {
case .success:
os_log(.info, log: .subscription, "[AppStorePurchaseFlow] purchaseSubscription -> restoreAccountFromPastPurchase: activeSubscriptionAlreadyPresent")
return .failure(.activeSubscriptionAlreadyPresent)
Expand Down Expand Up @@ -99,7 +99,7 @@ public final class AppStorePurchaseFlow {
return .success(())
case .failure(let error):
os_log(.error, log: .subscription, "[AppStorePurchaseFlow] purchaseSubscription error: %{public}s", String(reflecting: error))
AccountManager().signOut()
AccountManager(appGroup: subscriptionAppGroup).signOut()
switch error {
case .purchaseCancelledByUser:
return .failure(.cancelledByUser)
Expand All @@ -111,10 +111,10 @@ public final class AppStorePurchaseFlow {
// swiftlint:enable cyclomatic_complexity

@discardableResult
public static func completeSubscriptionPurchase() async -> Result<PurchaseUpdate, AppStorePurchaseFlow.Error> {
public static func completeSubscriptionPurchase(subscriptionAppGroup: String) async -> Result<PurchaseUpdate, AppStorePurchaseFlow.Error> {
os_log(.info, log: .subscription, "[AppStorePurchaseFlow] completeSubscriptionPurchase")

let result = await AccountManager.checkForEntitlements(wait: 2.0, retry: 20)
let result = await AccountManager.checkForEntitlements(subscriptionAppGroup: subscriptionAppGroup, wait: 2.0, retry: 20)

return result ? .success(PurchaseUpdate(type: "completed")) : .failure(.missingEntitlements)
}
Expand Down
4 changes: 2 additions & 2 deletions Sources/Subscription/Flows/AppStore/AppStoreRestoreFlow.swift
Original file line number Diff line number Diff line change
Expand Up @@ -34,15 +34,15 @@ public final class AppStoreRestoreFlow {
case subscriptionExpired(accountDetails: RestoredAccountDetails)
}

public static func restoreAccountFromPastPurchase() async -> Result<Void, AppStoreRestoreFlow.Error> {
public static func restoreAccountFromPastPurchase(appGroup: String) async -> Result<Void, AppStoreRestoreFlow.Error> {
Copy link
Contributor

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

@quanganhdo Spotted this discrepancy, the param probably should be subscriptionAppGroup for consistency.

os_log(.info, log: .subscription, "[AppStoreRestoreFlow] restoreAccountFromPastPurchase")

guard let lastTransactionJWSRepresentation = await PurchaseManager.mostRecentTransaction() else {
os_log(.error, log: .subscription, "[AppStoreRestoreFlow] Error: missingAccountOrTransactions")
return .failure(.missingAccountOrTransactions)
}

let accountManager = AccountManager()
let accountManager = AccountManager(appGroup: appGroup)

// Do the store login to get short-lived token
let authToken: String
Expand Down
Loading
Loading