diff --git a/Sources/Persist/PersistMacro.swift b/Sources/Persist/PersistMacro.swift index dce3376..612870b 100644 --- a/Sources/Persist/PersistMacro.swift +++ b/Sources/Persist/PersistMacro.swift @@ -73,6 +73,15 @@ public macro Persist( cacheValue: Bool = false ) = #externalMacro(module: "PersistMacros", type: "Persist_UserDefaults_NoTransformer") +@attached(peer, names: suffixed(_storage), prefixed(`$`), suffixed(_cache)) +@attached(accessor) +public macro Persist( + key: String, + userDefaults: UserDefaults, + transformer: any ThrowingTransformer, + cacheValue: Bool = false +) = #externalMacro(module: "PersistMacros", type: "Persist_UserDefaults_NoTransformer") + import Foundation public struct UpdateListenerWrapper: Sendable { diff --git a/Sources/PersistMacros/PersistMacroMacro.swift b/Sources/PersistMacros/PersistMacroMacro.swift index 7acac62..d947ab9 100644 --- a/Sources/PersistMacros/PersistMacroMacro.swift +++ b/Sources/PersistMacros/PersistMacroMacro.swift @@ -41,11 +41,10 @@ extension UserDefaultsMacro { in context: some MacroExpansionContext, isMutating: Bool, isThrowing: Bool, - transformerModifier: TransformerModifier? + transformerModifier: TransformerModifier = [] ) throws -> [AccessorDeclSyntax] { guard let property = declaration.as(VariableDeclSyntax.self), let binding = property.bindings.first, - let identifier = binding.pattern.as(IdentifierPatternSyntax.self)?.identifier, binding.accessorBlock == nil else { throw HashableMacroDiagnosticMessage( @@ -62,33 +61,6 @@ extension UserDefaultsMacro { severity: .error ) } - - func unwrapBaseType(_ type: TypeSyntax) -> BaseType? { - if let optionalType = type.as(OptionalTypeSyntax.self) { - if let wrappedType = unwrapBaseType(optionalType.wrappedType) { - return .optional(wrappedType) - } else { - return nil - } - } else if let identifier = type.as(IdentifierTypeSyntax.self) { - return .identifier(identifier) - } else if let dictionary = type.as(DictionaryTypeSyntax.self) { - return .dictionary(dictionary) - } else if let array = type.as(ArrayTypeSyntax.self) { - // TODO: Check for things like [Int8], which are not supported. - return .array(array) - } else { - return nil - } - } - - guard let baseType = unwrapBaseType(typeAnnotation.type) else { - throw HashableMacroDiagnosticMessage( - id: "unsupported-type-annotation", - message: "@Persist does not support this type annotation.", - severity: .error - ) - } let labeledArguments = node.arguments?.as(LabeledExprListSyntax.self) ?? [] var keyExpression: ExprSyntax? @@ -118,8 +90,8 @@ extension UserDefaultsMacro { .text } .joined(separator: ".") - } else if transformerExpression != nil { - "\(identifier)_transformer" + } else if let transformerExpression { + "\(transformerExpression)" } else { nil } @@ -134,85 +106,175 @@ extension UserDefaultsMacro { let userDefaultsPropertyName = try userDefaultsAccessor(labeledArguments: labeledArguments) + if let transformer { + let getter: DeclSyntax + let setter: DeclSyntax + + let isOptional = typeAnnotation.type.is(OptionalTypeSyntax.self) + + var valueAccessor = """ + if let storedValue = \(userDefaultsPropertyName).object(forKey: \(keyExpression)) as? type(of: \(transformer)) .Output { + return try \(transformer).transformOutput(storedValue) + } + """ + + if isOptional { + valueAccessor += """ + + return nil + """ + } else if let defaultValue = binding.initializer?.value { + valueAccessor += """ + + return \(defaultValue) + """ + } else { + throw HashableMacroDiagnosticMessage( + id: "non-optional-unsupported", + message: "Non-optionals properties must have a default value.", + severity: .error + ) + } + + if transformerModifier.contains(.throwing) { + getter = """ + get throws { + \(raw: valueAccessor) + } + """ + setter = """ + \(raw: isMutating ? "mutating" : "nonmutating") set throws { + let transformedValue = try \(raw: transformer).transformInput(newValue) + \(raw: userDefaultsPropertyName).set(transformedValue, forKey: \(keyExpression)) + } + """ + } else { + getter = """ + get { + \(raw: valueAccessor) + } + """ + setter = """ + \(raw: isMutating ? "mutating" : "nonmutating") set { + let transformedValue = \(raw: transformer).transformInput(newValue) + \(raw: userDefaultsPropertyName).set(transformedValue, forKey: \(keyExpression)) + } + """ + } + + return [ + """ + \(getter) + \(setter) + """ + ] + } else { + func unwrapBaseType(_ type: TypeSyntax) -> BaseType? { + if let optionalType = type.as(OptionalTypeSyntax.self) { + if let wrappedType = unwrapBaseType(optionalType.wrappedType) { + return .optional(wrappedType) + } else { + return nil + } + } else if let identifier = type.as(IdentifierTypeSyntax.self) { + return .identifier(identifier) + } else if let dictionary = type.as(DictionaryTypeSyntax.self) { + return .dictionary(dictionary) + } else if let array = type.as(ArrayTypeSyntax.self) { + // TODO: Check for things like [Int8], which are not supported. + return .array(array) + } else { + return nil + } + } + + guard let baseType = unwrapBaseType(typeAnnotation.type) else { + throw HashableMacroDiagnosticMessage( + id: "unsupported-type-annotation", + message: "@Persist does not support this type annotation.", + severity: .error + ) + } - let valueSetter = """ - \(userDefaultsPropertyName).set(newValue, forKey: \(keyExpression)) - """ + let valueSetter = """ + \(userDefaultsPropertyName).set(newValue, forKey: \(keyExpression)) + """ - func valueAccessor(forBaseType baseType: BaseType) throws -> String { - switch baseType { - case .optional(let baseType): - try valueAccessor(forBaseType: baseType) - case .identifier(let identifierTypeSyntax): - switch identifierTypeSyntax.name.trimmed.text { - case "Bool", "Int", "UInt", "Int8", "UInt8", "Int16", "UInt16", "Int32", "UInt32", "Int64", "UInt64", "Float", "Double", "String", "Data", "Date", "CGFloat", "NSNumber": + func valueAccessor(forBaseType baseType: BaseType) throws -> String { + switch baseType { + case .optional(let baseType): + try valueAccessor(forBaseType: baseType) + case .identifier(let identifierTypeSyntax): + switch identifierTypeSyntax.name.trimmed.text { + case "Bool", "Int", "UInt", "Int8", "UInt8", "Int16", "UInt16", "Int32", "UInt32", "Int64", "UInt64", "Float", "Double", "String", "Data", "Date", "CGFloat", "NSNumber": + """ + if let value = \(userDefaultsPropertyName).object(forKey: \(keyExpression)) as? \(identifierTypeSyntax.name.trimmed) { + return value + } + """ + case "URL", "NSURL": + // URLs are actually stored as Data. We must use url(forKey:) to decode it. + """ + // The stored object must be data. This is how URLs are stored by user defaults and it + // prevents user defaults from trying to coerce e.g. a string to a URL by assuming that + // it uses the 'file' protocol. + if \(userDefaultsPropertyName).object(forKey: \(keyExpression)) is Data, let value = \(userDefaultsPropertyName).url(forKey: \(keyExpression)) { + return value + } + """ + default: + throw HashableMacroDiagnosticMessage( + id: "unsupported-type", + message: "The '\(identifierTypeSyntax.name.trimmed.text)' type is not supported. If it is a typealias provide the original type.", + severity: .error + ) + } + case .array(let arrayTypeSyntax): """ - if let value = \(userDefaultsPropertyName).object(forKey: \(keyExpression)) as? \(identifierTypeSyntax.name.trimmed) { + if let value = \(userDefaultsPropertyName).object(forKey: \(keyExpression)) as? \(arrayTypeSyntax) { return value } """ - case "URL", "NSURL": - // URLs are actually stored as Data. We must use url(forKey:) to decode it. + case .dictionary(let dictionaryTypeSyntax): """ - // The stored object must be data. This is how URLs are stored by user defaults and it - // prevents user defaults from trying to coerce e.g. a string to a URL by assuming that - // it uses the 'file' protocol. - if \(userDefaultsPropertyName).object(forKey: \(keyExpression)) is Data, let value = \(userDefaultsPropertyName).url(forKey: \(keyExpression)) { + if let value = \(userDefaultsPropertyName).object(forKey: \(keyExpression)) as? \(dictionaryTypeSyntax) { return value } """ - default: - throw HashableMacroDiagnosticMessage( - id: "unsupported-type", - message: "The '\(identifierTypeSyntax.name.trimmed.text)' type is not supported. If it is a typealias provide the original type.", - severity: .error - ) } - case .array(let arrayTypeSyntax): + } + + var valueAccessor: String = try valueAccessor(forBaseType: baseType) + + if baseType.isOptional { + valueAccessor += """ + + return nil """ - if let value = \(userDefaultsPropertyName).object(forKey: \(keyExpression)) as? \(arrayTypeSyntax) { - return value - } + } else if let defaultValue = binding.initializer?.value { + valueAccessor += """ + + return \(defaultValue) """ - case .dictionary(let dictionaryTypeSyntax): + } else { + throw HashableMacroDiagnosticMessage( + id: "non-optional-unsupported", + message: "Non-optionals properties must have a default value.", + severity: .error + ) + } + + return [ """ - if let value = \(userDefaultsPropertyName).object(forKey: \(keyExpression)) as? \(dictionaryTypeSyntax) { - return value + get { + \(raw: valueAccessor) + } + set { + \(raw: valueSetter) } """ - } - } - - var valueAccessor: String = try valueAccessor(forBaseType: baseType) - - if baseType.isOptional { - valueAccessor += """ - - return nil - """ - } else if let defaultValue = binding.initializer?.value { - valueAccessor += """ - - return \(defaultValue) - """ - } else { - throw HashableMacroDiagnosticMessage( - id: "non-optional-unsupported", - message: "Non-optionals properties must have a default value.", - severity: .error - ) + ] } - - return [ - """ - get { - \(raw: valueAccessor) - } - set { - \(raw: valueSetter) - } - """ - ] } static func userDefaultsAccessor(labeledArguments: LabeledExprListSyntax) throws -> String { @@ -381,7 +443,7 @@ public struct Persist_UserDefaults_NoTransformer: UserDefaultsMacro { in: context, isMutating: false, isThrowing: false, - transformerModifier: nil + transformerModifier: [] ) } diff --git a/Tests/PersistTests/PersistAPITests.swift b/Tests/PersistTests/PersistAPITests.swift index 7c29849..4e9621a 100644 --- a/Tests/PersistTests/PersistAPITests.swift +++ b/Tests/PersistTests/PersistAPITests.swift @@ -38,7 +38,7 @@ struct TestStruct: Sendable { // @Persist( // key: "transformed-key", -// storage: UserDefaultsStorage(.standard), +// userDefaults: .standard, // transformer: JSONTransformer() // ) // var transformedProperty: TaskPriority? diff --git a/Tests/PersistTests/PersistMacroTests.swift b/Tests/PersistTests/PersistMacroTests.swift index cd67916..e66af75 100644 --- a/Tests/PersistTests/PersistMacroTests.swift +++ b/Tests/PersistTests/PersistMacroTests.swift @@ -5,37 +5,30 @@ import XCTest // Macro implementations build for the host, so the corresponding module is not available when cross-compiling. Cross-compiled tests may still make use of the macro itself in end-to-end tests. #if canImport(PersistMacros) import PersistMacros - -let testMacros: [String: Macro.Type] = [ - "Persist": Persist_Storage_NoTransformer.self, -] #endif final class PersistMacroTests: XCTestCase { - func testMacro() throws { + func testUserDefaultsTransformer() throws { #if canImport(PersistMacros) assertMacroExpansion( """ struct Setting { - @Persist(key: "foo", storage: UserDefaultsStorage(.standard)) - var testProperty: Int = 0 + @Persist( + key: "transformed-key", + userDefaults: .standard, + transformer: JSONTransformer() + ) + var transformedProperty: TaskPriority? } """, expandedSource: """ struct Setting { - var testProperty: Int = 0 { - get { - testProperty_storage.getValue(forKey: "foo") ?? 0 - } - nonmutating set { - testProperty_storage.setValue(newValue, forKey: "foo") - } - } - - private let testProperty_storage = UserDefaultsStorage(.standard) + var transformedProperty: TaskPriority? } """, - macros: testMacros + macros: [ + "Persist": Persist_UserDefaults_NoTransformer.self, + ] ) #else throw XCTSkip("macros are only supported when running tests for the host platform")