From aa4b8a46839252499126bfb182f0a710f55c78f5 Mon Sep 17 00:00:00 2001 From: Rachit Sonthalia Date: Mon, 3 Jun 2024 15:52:20 +0530 Subject: [PATCH] refactored code --- zk/txpool/policy.go | 32 +++++++++++++++++++++++++++----- zk/txpool/pool.go | 44 ++++++++++++++++---------------------------- 2 files changed, 43 insertions(+), 33 deletions(-) diff --git a/zk/txpool/policy.go b/zk/txpool/policy.go index 5e0027a2a29..d0d50f97782 100644 --- a/zk/txpool/policy.go +++ b/zk/txpool/policy.go @@ -25,16 +25,38 @@ func containsPolicy(policies []byte, policy PolicyName) bool { } // create a method checkpolicy to check an address according to passed policy in the method -func (p *TxPool) checkPolicy(addr common.Address, policy PolicyName, mode string) (bool, error) { - var table string +func (p *TxPool) checkPolicy(addr common.Address, policy PolicyName) (bool, error) { + // Retrieve the mode configuration + var mode string + err := p.aclDB.View(context.TODO(), func(tx kv.Tx) error { + value, err := tx.GetOne(Config, []byte("mode")) + if err != nil { + return err + } + if value == nil || string(value) == "disabled" { + mode = "disabled" + return nil + } + + mode = string(value) + return nil + }) + if err != nil { + return false, err + } + + if mode == "disabled" { + return true, nil + } + + // Determine the appropriate table based on the mode + table := Blacklist if mode == "allowlist" { table = Whitelist - } else { - table = Blacklist } var policyBytes []byte - err := p.aclDB.View(context.TODO(), func(tx kv.Tx) error { + err = p.aclDB.View(context.TODO(), func(tx kv.Tx) error { value, err := tx.GetOne(table, addr.Bytes()) if err != nil { return err diff --git a/zk/txpool/pool.go b/zk/txpool/pool.go index a35938af377..45636886eaa 100644 --- a/zk/txpool/pool.go +++ b/zk/txpool/pool.go @@ -713,37 +713,25 @@ func (p *TxPool) validateTx(txn *types.TxSlot, isLocal bool, stateCache kvcache. return InsufficientFunds } - var mode string - p.aclDB.View(context.TODO(), func(tx kv.Tx) error { - value, err := tx.GetOne(Config, []byte("mode")) + switch resolvePolicy(txn) { + case SendTx: + var allow bool + allow, err := p.checkPolicy(from, SendTx) if err != nil { panic(err) } - mode = string(value) - return nil - }) - - if mode != "disabled" { - switch resolvePolicy(txn) { - case SendTx: - var allow bool - allow, err := p.checkPolicy(from, SendTx, mode) - if err != nil { - panic(err) - } - if !allow { - return SenderDisallowedSendTx - } - case Deploy: - var allow bool - // check that sender may deploy contracts - allow, err := p.checkPolicy(from, Deploy, mode) - if err != nil { - panic(err) - } - if !allow { - return SenderDisallowedDeploy - } + if !allow { + return SenderDisallowedSendTx + } + case Deploy: + var allow bool + // check that sender may deploy contracts + allow, err := p.checkPolicy(from, Deploy) + if err != nil { + panic(err) + } + if !allow { + return SenderDisallowedDeploy } }