Skip to content

Commit

Permalink
Very WIP support for transformers
Browse files Browse the repository at this point in the history
  • Loading branch information
JosephDuffy committed Oct 11, 2024
1 parent bc94381 commit 69c4d6d
Show file tree
Hide file tree
Showing 4 changed files with 177 additions and 113 deletions.
9 changes: 9 additions & 0 deletions Sources/Persist/PersistMacro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -73,6 +73,15 @@ public macro Persist<Root>(
cacheValue: Bool = false
) = #externalMacro(module: "PersistMacros", type: "Persist_UserDefaults_NoTransformer")

@attached(peer, names: suffixed(_storage), prefixed(`$`), suffixed(_cache))
@attached(accessor)
public macro Persist<Input, Output>(
key: String,
userDefaults: UserDefaults,
transformer: any ThrowingTransformer<Input, Output>,
cacheValue: Bool = false
) = #externalMacro(module: "PersistMacros", type: "Persist_UserDefaults_NoTransformer")

import Foundation

public struct UpdateListenerWrapper<Value>: Sendable {
Expand Down
250 changes: 156 additions & 94 deletions Sources/PersistMacros/PersistMacroMacro.swift
Original file line number Diff line number Diff line change
Expand Up @@ -41,11 +41,10 @@ extension UserDefaultsMacro {
in context: some MacroExpansionContext,
isMutating: Bool,
isThrowing: Bool,
transformerModifier: TransformerModifier?
transformerModifier: TransformerModifier = []
) throws -> [AccessorDeclSyntax] {
guard let property = declaration.as(VariableDeclSyntax.self),
let binding = property.bindings.first,
let identifier = binding.pattern.as(IdentifierPatternSyntax.self)?.identifier,
binding.accessorBlock == nil
else {
throw HashableMacroDiagnosticMessage(
Expand All @@ -62,33 +61,6 @@ extension UserDefaultsMacro {
severity: .error
)
}

func unwrapBaseType(_ type: TypeSyntax) -> BaseType? {
if let optionalType = type.as(OptionalTypeSyntax.self) {
if let wrappedType = unwrapBaseType(optionalType.wrappedType) {
return .optional(wrappedType)
} else {
return nil
}
} else if let identifier = type.as(IdentifierTypeSyntax.self) {
return .identifier(identifier)
} else if let dictionary = type.as(DictionaryTypeSyntax.self) {
return .dictionary(dictionary)
} else if let array = type.as(ArrayTypeSyntax.self) {
// TODO: Check for things like [Int8], which are not supported.
return .array(array)
} else {
return nil
}
}

guard let baseType = unwrapBaseType(typeAnnotation.type) else {
throw HashableMacroDiagnosticMessage(
id: "unsupported-type-annotation",
message: "@Persist does not support this type annotation.",
severity: .error
)
}
let labeledArguments = node.arguments?.as(LabeledExprListSyntax.self) ?? []

var keyExpression: ExprSyntax?
Expand Down Expand Up @@ -118,8 +90,8 @@ extension UserDefaultsMacro {
.text
}
.joined(separator: ".")
} else if transformerExpression != nil {
"\(identifier)_transformer"
} else if let transformerExpression {
"\(transformerExpression)"
} else {
nil
}
Expand All @@ -134,85 +106,175 @@ extension UserDefaultsMacro {

let userDefaultsPropertyName = try userDefaultsAccessor(labeledArguments: labeledArguments)

if let transformer {
let getter: DeclSyntax
let setter: DeclSyntax

let isOptional = typeAnnotation.type.is(OptionalTypeSyntax.self)

var valueAccessor = """
if let storedValue = \(userDefaultsPropertyName).object(forKey: \(keyExpression)) as? type(of: \(transformer)) .Output {
return try \(transformer).transformOutput(storedValue)
}
"""

if isOptional {
valueAccessor += """
return nil
"""
} else if let defaultValue = binding.initializer?.value {
valueAccessor += """
return \(defaultValue)
"""
} else {
throw HashableMacroDiagnosticMessage(
id: "non-optional-unsupported",
message: "Non-optionals properties must have a default value.",
severity: .error
)
}

if transformerModifier.contains(.throwing) {
getter = """
get throws {
\(raw: valueAccessor)
}
"""
setter = """
\(raw: isMutating ? "mutating" : "nonmutating") set throws {
let transformedValue = try \(raw: transformer).transformInput(newValue)
\(raw: userDefaultsPropertyName).set(transformedValue, forKey: \(keyExpression))
}
"""
} else {
getter = """
get {
\(raw: valueAccessor)
}
"""
setter = """
\(raw: isMutating ? "mutating" : "nonmutating") set {
let transformedValue = \(raw: transformer).transformInput(newValue)
\(raw: userDefaultsPropertyName).set(transformedValue, forKey: \(keyExpression))
}
"""
}

return [
"""
\(getter)
\(setter)
"""
]
} else {
func unwrapBaseType(_ type: TypeSyntax) -> BaseType? {
if let optionalType = type.as(OptionalTypeSyntax.self) {
if let wrappedType = unwrapBaseType(optionalType.wrappedType) {
return .optional(wrappedType)
} else {
return nil
}
} else if let identifier = type.as(IdentifierTypeSyntax.self) {
return .identifier(identifier)
} else if let dictionary = type.as(DictionaryTypeSyntax.self) {
return .dictionary(dictionary)
} else if let array = type.as(ArrayTypeSyntax.self) {
// TODO: Check for things like [Int8], which are not supported.
return .array(array)
} else {
return nil
}
}

guard let baseType = unwrapBaseType(typeAnnotation.type) else {
throw HashableMacroDiagnosticMessage(
id: "unsupported-type-annotation",
message: "@Persist does not support this type annotation.",
severity: .error
)
}

let valueSetter = """
\(userDefaultsPropertyName).set(newValue, forKey: \(keyExpression))
"""
let valueSetter = """
\(userDefaultsPropertyName).set(newValue, forKey: \(keyExpression))
"""

func valueAccessor(forBaseType baseType: BaseType) throws -> String {
switch baseType {
case .optional(let baseType):
try valueAccessor(forBaseType: baseType)
case .identifier(let identifierTypeSyntax):
switch identifierTypeSyntax.name.trimmed.text {
case "Bool", "Int", "UInt", "Int8", "UInt8", "Int16", "UInt16", "Int32", "UInt32", "Int64", "UInt64", "Float", "Double", "String", "Data", "Date", "CGFloat", "NSNumber":
func valueAccessor(forBaseType baseType: BaseType) throws -> String {
switch baseType {
case .optional(let baseType):
try valueAccessor(forBaseType: baseType)
case .identifier(let identifierTypeSyntax):
switch identifierTypeSyntax.name.trimmed.text {
case "Bool", "Int", "UInt", "Int8", "UInt8", "Int16", "UInt16", "Int32", "UInt32", "Int64", "UInt64", "Float", "Double", "String", "Data", "Date", "CGFloat", "NSNumber":
"""
if let value = \(userDefaultsPropertyName).object(forKey: \(keyExpression)) as? \(identifierTypeSyntax.name.trimmed) {
return value
}
"""
case "URL", "NSURL":
// URLs are actually stored as Data. We must use url(forKey:) to decode it.
"""
// The stored object must be data. This is how URLs are stored by user defaults and it
// prevents user defaults from trying to coerce e.g. a string to a URL by assuming that
// it uses the 'file' protocol.
if \(userDefaultsPropertyName).object(forKey: \(keyExpression)) is Data, let value = \(userDefaultsPropertyName).url(forKey: \(keyExpression)) {
return value
}
"""
default:
throw HashableMacroDiagnosticMessage(
id: "unsupported-type",
message: "The '\(identifierTypeSyntax.name.trimmed.text)' type is not supported. If it is a typealias provide the original type.",
severity: .error
)
}
case .array(let arrayTypeSyntax):
"""
if let value = \(userDefaultsPropertyName).object(forKey: \(keyExpression)) as? \(identifierTypeSyntax.name.trimmed) {
if let value = \(userDefaultsPropertyName).object(forKey: \(keyExpression)) as? \(arrayTypeSyntax) {
return value
}
"""
case "URL", "NSURL":
// URLs are actually stored as Data. We must use url(forKey:) to decode it.
case .dictionary(let dictionaryTypeSyntax):
"""
// The stored object must be data. This is how URLs are stored by user defaults and it
// prevents user defaults from trying to coerce e.g. a string to a URL by assuming that
// it uses the 'file' protocol.
if \(userDefaultsPropertyName).object(forKey: \(keyExpression)) is Data, let value = \(userDefaultsPropertyName).url(forKey: \(keyExpression)) {
if let value = \(userDefaultsPropertyName).object(forKey: \(keyExpression)) as? \(dictionaryTypeSyntax) {
return value
}
"""
default:
throw HashableMacroDiagnosticMessage(
id: "unsupported-type",
message: "The '\(identifierTypeSyntax.name.trimmed.text)' type is not supported. If it is a typealias provide the original type.",
severity: .error
)
}
case .array(let arrayTypeSyntax):
}

var valueAccessor: String = try valueAccessor(forBaseType: baseType)

if baseType.isOptional {
valueAccessor += """
return nil
"""
if let value = \(userDefaultsPropertyName).object(forKey: \(keyExpression)) as? \(arrayTypeSyntax) {
return value
}
} else if let defaultValue = binding.initializer?.value {
valueAccessor += """
return \(defaultValue)
"""
case .dictionary(let dictionaryTypeSyntax):
} else {
throw HashableMacroDiagnosticMessage(
id: "non-optional-unsupported",
message: "Non-optionals properties must have a default value.",
severity: .error
)
}

return [
"""
if let value = \(userDefaultsPropertyName).object(forKey: \(keyExpression)) as? \(dictionaryTypeSyntax) {
return value
get {
\(raw: valueAccessor)
}
set {
\(raw: valueSetter)
}
"""
}
}

var valueAccessor: String = try valueAccessor(forBaseType: baseType)

if baseType.isOptional {
valueAccessor += """
return nil
"""
} else if let defaultValue = binding.initializer?.value {
valueAccessor += """
return \(defaultValue)
"""
} else {
throw HashableMacroDiagnosticMessage(
id: "non-optional-unsupported",
message: "Non-optionals properties must have a default value.",
severity: .error
)
]
}

return [
"""
get {
\(raw: valueAccessor)
}
set {
\(raw: valueSetter)
}
"""
]
}

static func userDefaultsAccessor(labeledArguments: LabeledExprListSyntax) throws -> String {
Expand Down Expand Up @@ -381,7 +443,7 @@ public struct Persist_UserDefaults_NoTransformer: UserDefaultsMacro {
in: context,
isMutating: false,
isThrowing: false,
transformerModifier: nil
transformerModifier: []
)
}

Expand Down
2 changes: 1 addition & 1 deletion Tests/PersistTests/PersistAPITests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -38,7 +38,7 @@ struct TestStruct: Sendable {

// @Persist(
// key: "transformed-key",
// storage: UserDefaultsStorage(.standard),
// userDefaults: .standard,
// transformer: JSONTransformer<TaskPriority>()
// )
// var transformedProperty: TaskPriority?
Expand Down
29 changes: 11 additions & 18 deletions Tests/PersistTests/PersistMacroTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -5,37 +5,30 @@ import XCTest
// Macro implementations build for the host, so the corresponding module is not available when cross-compiling. Cross-compiled tests may still make use of the macro itself in end-to-end tests.
#if canImport(PersistMacros)
import PersistMacros

let testMacros: [String: Macro.Type] = [
"Persist": Persist_Storage_NoTransformer.self,
]
#endif

final class PersistMacroTests: XCTestCase {
func testMacro() throws {
func testUserDefaultsTransformer() throws {
#if canImport(PersistMacros)
assertMacroExpansion(
"""
struct Setting {
@Persist(key: "foo", storage: UserDefaultsStorage(.standard))
var testProperty: Int = 0
@Persist(
key: "transformed-key",
userDefaults: .standard,
transformer: JSONTransformer<TaskPriority>()
)
var transformedProperty: TaskPriority?
}
""",
expandedSource: """
struct Setting {
var testProperty: Int = 0 {
get {
testProperty_storage.getValue(forKey: "foo") ?? 0
}
nonmutating set {
testProperty_storage.setValue(newValue, forKey: "foo")
}
}
private let testProperty_storage = UserDefaultsStorage(.standard)
var transformedProperty: TaskPriority?
}
""",
macros: testMacros
macros: [
"Persist": Persist_UserDefaults_NoTransformer.self,
]
)
#else
throw XCTSkip("macros are only supported when running tests for the host platform")
Expand Down

0 comments on commit 69c4d6d

Please sign in to comment.