Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Added simulateTransaction endpoint to the rpc provider #78

Merged
merged 7 commits into from
Jul 5, 2023
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
10 changes: 10 additions & 0 deletions Sources/Starknet/Data/Events.swift
Original file line number Diff line number Diff line change
Expand Up @@ -56,3 +56,13 @@ public struct StarknetEvent: Decodable, Equatable {
case data
}
}

public struct StarknetEventContent: Decodable, Equatable {
public let keys: [Felt]
public let data: [Felt]

enum CodingKeys: String, CodingKey {
case keys
case data
}
}
167 changes: 167 additions & 0 deletions Sources/Starknet/Data/Transaction/TransactionTrace.swift
Original file line number Diff line number Diff line change
@@ -0,0 +1,167 @@
import Foundation

public enum StarknetEntryPointType: String, Decodable {
case external = "EXTERNAL"
case l1Handler = "L1_HANDLER"
case constructor = "CONSTRUCTOR"
}

public enum StarknetCallType: String, Decodable {
case call = "CALL"
case libraryCall = "LIBRARY_CALL"
}

public enum StarknetSimulationFlag: String, Codable {
case skipValidate = "SKIP_VALIDATE"
case skipExecute = "SKIP_EXECUTE"
}

public struct StarknetFunctionInvocation: Decodable, Equatable {
public let contractAddress: Felt
public let entrypoint: Felt
public let calldata: StarknetCalldata
public let callerAddress: Felt?
public let classHash: Felt?
public let entryPointType: StarknetEntryPointType?
public let callType: StarknetCallType?
public let result: [Felt]?
public let calls: [StarknetFunctionInvocation]?
public let events: [StarknetEventContent]?
public let messages: [MessageToL1]?

private enum CodingKeys: String, CodingKey {
case contractAddress = "contract_address"
case entrypoint = "entry_point_selector"
case calldata
case callerAddress = "caller_address"
case classHash = "class_hash"
case entryPointType = "entry_point_type"
case callType = "call_type"
case result
case calls
case events
case messages
}
}

public protocol StarknetTransactionTrace: Decodable, Equatable {}

public struct StarknetInvokeTransactionTrace: StarknetTransactionTrace {
public let validateInvocation: StarknetFunctionInvocation?
public let executeInvocation: StarknetFunctionInvocation?
public let feeTransferInvocation: StarknetFunctionInvocation?

private enum CodingKeys: String, CodingKey {
case validateInvocation = "validate_invocation"
case executeInvocation = "execute_invocation"
case feeTransferInvocation = "fee_transfer_invocation"
}
}

public struct StarknetDeployAccountTransactionTrace: StarknetTransactionTrace {
public let validateInvocation: StarknetFunctionInvocation?
public let constructorInvocation: StarknetFunctionInvocation?
public let feeTransferInvocation: StarknetFunctionInvocation?

private enum CodingKeys: String, CodingKey {
case validateInvocation = "validate_invocation"
case constructorInvocation = "constructor_invocation"
case feeTransferInvocation = "fee_transfer_invocation"
}
}

public struct StarknetL1HandlerTransactionTrace: StarknetTransactionTrace {
public let functionInvocation: StarknetFunctionInvocation?

private enum CodingKeys: String, CodingKey {
case functionInvocation = "function_invocation"
}
}

enum StarknetTransactionTraceWrapper: Decodable {
fileprivate enum Keys: String, CodingKey {
case validateInvocation = "validate_invocation"
case executeInvocation = "execute_invocation"
case feeTransferInvocation = "fee_transfer_invocation"
case constructorInvocation = "constructor_invocation"
case functionInvocation = "function_invocation"
}

case invoke(StarknetInvokeTransactionTrace)
case deployAccount(StarknetDeployAccountTransactionTrace)
case l1Handler(StarknetL1HandlerTransactionTrace)

public var transactionTrace: any StarknetTransactionTrace {
switch self {
case let .invoke(txTrace):
return txTrace
case let .deployAccount(txTrace):
return txTrace
case let .l1Handler(txTrace):
return txTrace
}
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: Keys.self)

// Invocations can be null, so `if let = try?` syntax won't work here.
do {
let validateInvocation = try container.decode(StarknetFunctionInvocation?.self, forKey: .validateInvocation)
let executeInvocation = try container.decode(StarknetFunctionInvocation?.self, forKey: .executeInvocation)
let feeTransferInvocation = try container.decode(StarknetFunctionInvocation?.self, forKey: .feeTransferInvocation)

self = .invoke(StarknetInvokeTransactionTrace(
validateInvocation: validateInvocation,
executeInvocation: executeInvocation,
feeTransferInvocation: feeTransferInvocation
))
return
} catch {}

do {
let validateInvocation = try container.decode(StarknetFunctionInvocation?.self, forKey: .validateInvocation)
let constructorInvocation = try container.decode(StarknetFunctionInvocation?.self, forKey: .constructorInvocation)
let feeTransferInvocation = try container.decode(StarknetFunctionInvocation?.self, forKey: .feeTransferInvocation)

self = .deployAccount(StarknetDeployAccountTransactionTrace(
validateInvocation: validateInvocation,
constructorInvocation: constructorInvocation,
feeTransferInvocation: feeTransferInvocation
))
return
} catch {}

do {
let functionInvocation = try container.decode(StarknetFunctionInvocation?.self, forKey: .functionInvocation)

self = .l1Handler(StarknetL1HandlerTransactionTrace(
functionInvocation: functionInvocation
))
return
} catch {}

throw DecodingError.dataCorrupted(
DecodingError.Context(
codingPath: container.codingPath,
debugDescription: "Invalid transaction trace"
))
}
}

public struct StarknetSimulatedTransaction: Decodable {
public let transactionTrace: any StarknetTransactionTrace
public let feeEstimation: StarknetFeeEstimate

enum CodingKeys: String, CodingKey {
case transactionTrace = "transaction_trace"
case feeEstimation = "fee_estimation"
}

public init(from decoder: Decoder) throws {
let container = try decoder.container(keyedBy: CodingKeys.self)

transactionTrace = try container.decode(StarknetTransactionTraceWrapper.self, forKey: .transactionTrace).transactionTrace
feeEstimation = try container.decode(StarknetFeeEstimate.self, forKey: .feeEstimation)
}
}
Original file line number Diff line number Diff line change
Expand Up @@ -13,4 +13,5 @@ enum JsonRpcMethod: String, Encodable {
case getTransactionByHash = "starknet_getTransactionByHash"
case getTransactionByBlockIdAndIndex = "starknet_getTransactionByBlockIdAndIndex"
case getTransactionReceipt = "starknet_getTransactionReceipt"
case simulateTransaction = "starknet_simulateTransaction"
}
40 changes: 31 additions & 9 deletions Sources/Starknet/Providers/StarknetProvider/JsonRpcParams.swift
Original file line number Diff line number Diff line change
Expand Up @@ -28,19 +28,19 @@ struct AddInvokeTransactionParams: Encodable {
}
}

// Workaround to allow encoding polymorphic array
struct WrappedSequencerTransaction: Encodable {
let transaction: any StarknetSequencerTransaction

func encode(to encoder: Encoder) throws {
try transaction.encode(to: encoder)
}
}

struct EstimateFeeParams: Encodable {
let request: [any StarknetSequencerTransaction]
let blockId: StarknetBlockId

// Walkaround to allow encoding polymorphic array
struct WrappedSequencerTransaction: Encodable {
let transaction: any StarknetSequencerTransaction

func encode(to encoder: Encoder) throws {
try transaction.encode(to: encoder)
}
}

func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)

Expand Down Expand Up @@ -104,4 +104,26 @@ struct GetTransactionReceiptPayload: Encodable {
}
}

struct SimulateTransactionsParams: Encodable {
let transactions: [any StarknetSequencerTransaction]
let blockId: StarknetBlockId
let simulationFlags: Set<StarknetSimulationFlag>

func encode(to encoder: Encoder) throws {
var container = encoder.container(keyedBy: CodingKeys.self)

let wrappedTransactions = transactions.map { WrappedSequencerTransaction(transaction: $0) }

try container.encode(wrappedTransactions, forKey: .transactions)
try container.encode(blockId, forKey: .blockId)
try container.encode(simulationFlags, forKey: .simulationFlags)
}

enum CodingKeys: String, CodingKey {
case transactions
case blockId = "block_id"
case simulationFlags = "simulation_flags"
}
}

struct EmptyParams: Encodable {}
Original file line number Diff line number Diff line change
Expand Up @@ -142,4 +142,12 @@ public class StarknetProvider: StarknetProviderProtocol {

return result.transactionReceipt
}

public func simulateTransactions(_ transactions: [any StarknetSequencerTransaction], at blockId: StarknetBlockId, simulationFlags: Set<StarknetSimulationFlag>) async throws -> [StarknetSimulatedTransaction] {
let params = SimulateTransactionsParams(transactions: transactions, blockId: blockId, simulationFlags: simulationFlags)

let result = try await makeRequest(method: .simulateTransaction, params: params, receive: [StarknetSimulatedTransaction].self)

return result
}
}
19 changes: 18 additions & 1 deletion Sources/Starknet/Providers/StarknetProviderProtocol.swift
Original file line number Diff line number Diff line change
Expand Up @@ -95,9 +95,17 @@ public protocol StarknetProviderProtocol {
/// - txHash : the hash of the requested transaction
/// - Returns: receipt of a transaction identified by given hash
func getTransactionReceiptBy(hash: Felt) async throws -> StarknetTransactionReceipt

/// Simulate running a given list of transactions, and generate the execution trace
///
/// - Parameters:
/// - transactions: list of transactions to simulate
/// - blockId: block used to run the simulation
/// - simulationFlags: a set of simulation flags
func simulateTransactions(_ transactions: [any StarknetSequencerTransaction], at blockId: StarknetBlockId, simulationFlags: Set<StarknetSimulationFlag>) async throws -> [StarknetSimulatedTransaction]
}

private let defaultBlockId = StarknetBlockId.tag(.pending)
let defaultBlockId = StarknetBlockId.tag(.pending)

public extension StarknetProviderProtocol {
/// Call starknet contract in the pending block.
Expand Down Expand Up @@ -157,4 +165,13 @@ public extension StarknetProviderProtocol {
func getClassHashAt(_ address: Felt) async throws -> Felt {
try await getClassHashAt(address, at: defaultBlockId)
}

/// Simulate running a given list of transactions in the latest block, and generate the execution trace
///
/// - Parameters:
/// - transactions: list of transactions to simulate
/// - simulationFlags: a set of simulation flags
func simulateTransactions(_ transactions: [any StarknetSequencerTransaction], simulationFlags: Set<StarknetSimulationFlag>) async throws -> [StarknetSimulatedTransaction] {
try await simulateTransactions(transactions, at: defaultBlockId, simulationFlags: simulationFlags)
}
}
57 changes: 57 additions & 0 deletions Tests/StarknetTests/Providers/ProviderTests.swift
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,7 @@ final class ProviderTests: XCTestCase {
To run, make sure you're running starknet-devnet on port 5050, with seed 0
*/
static var devnetClient: DevnetClientProtocol!

var provider: StarknetProviderProtocol!

override class func setUp() {
Expand Down Expand Up @@ -151,4 +152,60 @@ final class ProviderTests: XCTestCase {

XCTAssertEqual(fees.count, 2)
}

func testSimulateTransactions() async throws {
let acc = try await ProviderTests.devnetClient.deployAccount(name: "test_simulate_transactions")
let signer = StarkCurveSigner(privateKey: acc.details.privateKey)!
let contract = try await ProviderTests.devnetClient.deployContract(contractName: "balance", deprecated: true)
let account = StarknetAccount(address: acc.details.address, signer: signer, provider: provider)

let nonce = try await account.getNonce()

let call = StarknetCall(contractAddress: contract.address, entrypoint: starknetSelector(from: "increase_balance"), calldata: [1000])
let params = StarknetExecutionParams(nonce: nonce, maxFee: 1_000_000_000_000)

let invokeTx = try account.sign(calls: [call], params: params, forFeeEstimation: true)

let accountClassHash = try await provider.getClassHashAt(account.address)
let newSigner = StarkCurveSigner(privateKey: 1234)!
let newPublicKey = newSigner.publicKey
let newAccountAddress = StarknetContractAddressCalculator.calculateFrom(classHash: accountClassHash, calldata: [newPublicKey], salt: .zero)
let newAccount = StarknetAccount(address: newAccountAddress, signer: newSigner, provider: provider)

try await Self.devnetClient.prefundAccount(address: newAccountAddress)

let newAccountParams = StarknetExecutionParams(nonce: 0, maxFee: 0)
let deployAccountTx = try newAccount.signDeployAccount(classHash: accountClassHash, calldata: [newPublicKey], salt: .zero, params: newAccountParams, forFeeEstimation: true)

let simulations = try await provider.simulateTransactions([invokeTx, deployAccountTx], at: .tag(.latest), simulationFlags: [])

XCTAssertEqual(simulations.count, 2)
XCTAssertTrue(simulations[0].transactionTrace is StarknetInvokeTransactionTrace)
XCTAssertTrue(simulations[1].transactionTrace is StarknetDeployAccountTransactionTrace)

let invokeWithoutSignature = StarknetSequencerInvokeTransaction(
senderAddress: invokeTx.senderAddress,
calldata: invokeTx.calldata,
signature: [],
maxFee: invokeTx.maxFee,
nonce: invokeTx.nonce,
version: invokeTx.version
)

let deployAccountWithoutSignature = StarknetSequencerDeployAccountTransaction(
signature: [],
maxFee: deployAccountTx.maxFee,
nonce: deployAccountTx.nonce,
contractAddressSalt: deployAccountTx.contractAddressSalt,
constructorCalldata: deployAccountTx.constructorCalldata,
classHash: deployAccountTx.classHash,
version: deployAccountTx.version
)

let simulations2 = try await provider.simulateTransactions([invokeWithoutSignature, deployAccountWithoutSignature], at: .tag(.latest), simulationFlags: [.skipValidate])

XCTAssertEqual(simulations2.count, 2)
XCTAssertTrue(simulations2[0].transactionTrace is StarknetInvokeTransactionTrace)
XCTAssertTrue(simulations2[1].transactionTrace is StarknetDeployAccountTransactionTrace)
}
}
6 changes: 6 additions & 0 deletions Tests/StarknetTests/Utils/DevnetClient/DevnetClient.swift
Original file line number Diff line number Diff line change
Expand Up @@ -59,6 +59,8 @@ func makeDevnetClient() -> DevnetClientProtocol {
private let devnetPath: String
private let starknetPath: String

private var deployedContracts: [String: TransactionResult] = [:]

let gatewayUrl: String
let feederGatewayUrl: String
let rpcUrl: String
Expand Down Expand Up @@ -220,6 +222,10 @@ func makeDevnetClient() -> DevnetClientProtocol {
public func deployContract(contractName: String, deprecated: Bool) async throws -> TransactionResult {
try guardDevnetIsRunning()

if let transactionResult = deployedContracts["contractName"] {
return transactionResult
}

let classHash = try await declareContract(contractName: contractName, deprecated: deprecated)

let params = [
Expand Down
4 changes: 2 additions & 2 deletions requirements.txt
Original file line number Diff line number Diff line change
@@ -1,2 +1,2 @@
starknet-devnet==0.5.2
cairo-lang==0.11.1.1
starknet-devnet==0.5.5
cairo-lang==0.12.0