diff --git a/x/authz/keeper/keeper.go b/x/authz/keeper/keeper.go index 18b0be5dc497..e2db0db44be9 100644 --- a/x/authz/keeper/keeper.go +++ b/x/authz/keeper/keeper.go @@ -258,19 +258,21 @@ func (k Keeper) DeleteGrant(ctx context.Context, grantee, granter sdk.AccAddress // DeleteAllGrants revokes all authorizations granted to the grantee by the granter. func (k Keeper) DeleteAllGrants(ctx context.Context, granter sdk.AccAddress) error { - store := runtime.KVStoreAdapter(k.KVStoreService.OpenKVStore(ctx)) - - granterStoreKey := granterStoreKey(granter) - iterator := storetypes.KVStorePrefixIterator(store, granterStoreKey) - defer iterator.Close() - - for ; iterator.Valid(); iterator.Next() { - // Directly delete the grant without deserializing, as we're deleting all grants for the msgType. - // TODO: need to test this function - store.Delete(iterator.Key()) + count := 0 + err := k.IterateGranterGrants(ctx, granter, func(grantee sdk.AccAddress, msgType string) (stop bool, err error) { + count++ + if err := k.DeleteGrant(ctx, grantee, granter, msgType); err != nil { + return false, err + } + return false, nil + }) + if err != nil { + return err } - - return k.EventService.EventManager((ctx)).Emit(&authz.EventRevokeAll{ + if count == 0 { + return errorsmod.Wrapf(authz.ErrNoAuthorizationFound, "no grants found for granter %s", granter) + } + return k.EventService.EventManager(ctx).Emit(&authz.EventRevokeAll{ Granter: granter.String(), }) } @@ -344,6 +346,29 @@ func (k Keeper) IterateGrants(ctx context.Context, return nil } +func (k Keeper) IterateGranterGrants(ctx context.Context, granter sdk.AccAddress, + handler func(granteeAddr sdk.AccAddress, msgType string) (stop bool, err error), +) error { + // no-op if handler is nil + if handler == nil { + return nil + } + store := runtime.KVStoreAdapter(k.KVStoreService.OpenKVStore(ctx)) + iter := storetypes.KVStorePrefixIterator(store, granterStoreKey(granter)) + defer iter.Close() + for ; iter.Valid(); iter.Next() { + _, granteeAddr, msgType := parseGrantStoreKey(iter.Key()) + ok, err := handler(granteeAddr, msgType) + if err != nil { + return err + } + if ok { + break + } + } + return nil +} + func (k Keeper) getGrantQueueItem(ctx context.Context, expiration time.Time, granter, grantee sdk.AccAddress) (*authz.GrantQueueItem, error) { store := k.KVStoreService.OpenKVStore(ctx) bz, err := store.Get(GrantQueueKey(expiration, granter, grantee)) diff --git a/x/authz/keeper/keeper_test.go b/x/authz/keeper/keeper_test.go index 23d25f2af98b..5fdb1e1b96e5 100644 --- a/x/authz/keeper/keeper_test.go +++ b/x/authz/keeper/keeper_test.go @@ -92,29 +92,59 @@ func (s *TestSuite) TestKeeper() { require := s.Require() granterAddr := addrs[0] - granteeAddr := addrs[1] + grantee1Addr := addrs[1] + grantee2Addr := addrs[2] + grantee3Addr := addrs[3] + grantees := []sdk.AccAddress{grantee1Addr, grantee2Addr, grantee3Addr} s.T().Log("verify that no authorization returns nil") authorizations, err := s.authzKeeper.GetAuthorizations(ctx, granteeAddr, granterAddr) require.NoError(err) require.Len(authorizations, 0) - s.T().Log("verify save, get and delete") + s.T().Log("verify save, get and delete work for grants") sendAutz := &banktypes.SendAuthorization{SpendLimit: coins100} expire := now.AddDate(1, 0, 0) - err = s.authzKeeper.SaveGrant(ctx, granteeAddr, granterAddr, sendAutz, &expire) + for _, grantee := range grantees { + err = s.authzKeeper.SaveGrant(ctx, grantee, granterAddr, sendAutz, &expire) + require.NoError(err) + } + + for _, grantee := range grantees { + authorizations, err = s.authzKeeper.GetAuthorizations(ctx, grantee, granterAddr) + require.NoError(err) + require.Len(authorizations, 1) + } + + err = s.authzKeeper.DeleteGrant(ctx, grantee1Addr, granterAddr, sendAutz.MsgTypeURL()) require.NoError(err) - authorizations, err = s.authzKeeper.GetAuthorizations(ctx, granteeAddr, granterAddr) + authorizations, err = s.authzKeeper.GetAuthorizations(ctx, grantee1Addr, granterAddr) + require.NoError(err) + require.Len(authorizations, 0) + authorizations, err = s.authzKeeper.GetAuthorizations(ctx, grantee2Addr, granterAddr) + require.NoError(err) + require.Len(authorizations, 1) + authorizations, err = s.authzKeeper.GetAuthorizations(ctx, grantee3Addr, granterAddr) require.NoError(err) require.Len(authorizations, 1) - err = s.authzKeeper.DeleteGrant(ctx, granteeAddr, granterAddr, sendAutz.MsgTypeURL()) + err = s.authzKeeper.DeleteAllGrants(ctx, granterAddr) require.NoError(err) - authorizations, err = s.authzKeeper.GetAuthorizations(ctx, granteeAddr, granterAddr) + authorizations, err = s.authzKeeper.GetAuthorizations(ctx, grantee1Addr, granterAddr) require.NoError(err) require.Len(authorizations, 0) + authorizations, err = s.authzKeeper.GetAuthorizations(ctx, grantee2Addr, granterAddr) + require.NoError(err) + require.Len(authorizations, 0) + authorizations, err = s.authzKeeper.GetAuthorizations(ctx, grantee3Addr, granterAddr) + require.NoError(err) + require.Len(authorizations, 0) + + // test delete all grants for granter with no grants, should error + err = s.authzKeeper.DeleteAllGrants(ctx, granterAddr) + require.Error(err) s.T().Log("verify granting same authorization overwrite existing authorization") err = s.authzKeeper.SaveGrant(ctx, granteeAddr, granterAddr, sendAutz, &expire) @@ -162,6 +192,39 @@ func (s *TestSuite) TestKeeperIter() { }) } +func (s *TestSuite) TestKeeperGranterGrantsIter() { + ctx, addrs := s.ctx, s.addrs + + granterAddr := addrs[0] + granter2Addr := addrs[1] + granteeAddr := addrs[2] + grantee2Addr := addrs[3] + grantee3Addr := addrs[4] + e := ctx.HeaderInfo().Time.AddDate(1, 0, 0) + sendAuthz := banktypes.NewSendAuthorization(coins100, nil, s.accountKeeper.AddressCodec()) + + err := s.authzKeeper.SaveGrant(ctx, granteeAddr, granterAddr, sendAuthz, &e) + s.Require().NoError(err) + + err = s.authzKeeper.SaveGrant(ctx, grantee2Addr, granterAddr, sendAuthz, &e) + s.Require().NoError(err) + + err = s.authzKeeper.SaveGrant(ctx, grantee3Addr, granter2Addr, sendAuthz, &e) + s.Require().NoError(err) + + _ = s.authzKeeper.IterateGranterGrants(ctx, granterAddr, func(grantee sdk.AccAddress, msgType string) (bool, error) { + s.Require().Contains([]sdk.AccAddress{granteeAddr, grantee2Addr}, grantee) + s.Require().NotContains([]sdk.AccAddress{grantee3Addr}, grantee) + return true, nil + }) + + _ = s.authzKeeper.IterateGranterGrants(ctx, granter2Addr, func(grantee sdk.AccAddress, msgType string) (bool, error) { + s.Require().Equal(grantee3Addr, grantee) + s.Require().NotContains([]sdk.AccAddress{granteeAddr, grantee2Addr}, grantee) + return true, nil + }) +} + func (s *TestSuite) TestDispatchAction() { addrs := s.addrs require := s.Require() diff --git a/x/authz/keeper/msg_server_test.go b/x/authz/keeper/msg_server_test.go index 5e4a25279b87..6313f8a3f8fc 100644 --- a/x/authz/keeper/msg_server_test.go +++ b/x/authz/keeper/msg_server_test.go @@ -435,3 +435,67 @@ func (suite *TestSuite) TestPruneExpiredGrants() { }) suite.Require().Equal(0, totalGrants) } + +func (suite *TestSuite) TestRevokeAllGrants() { + addrs := simtestutil.CreateIncrementalAccounts(3) + + grantee, grantee2, granter := addrs[0], addrs[1], addrs[2] + granterStrAddr, err := suite.accountKeeper.AddressCodec().BytesToString(granter) + suite.Require().NoError(err) + + testCases := []struct { + name string + malleate func() *authz.MsgRevokeAll + expErr bool + errMsg string + }{ + { + name: "invalid granter", + malleate: func() *authz.MsgRevokeAll { + return &authz.MsgRevokeAll{ + Granter: "invalid", + } + }, + expErr: true, + errMsg: "invalid bech32 string", + }, + { + name: "no existing grant to revoke", + malleate: func() *authz.MsgRevokeAll { + return &authz.MsgRevokeAll{ + Granter: granterStrAddr, + } + }, + expErr: true, + errMsg: "authorization not found", + }, + { + name: "valid grant", + malleate: func() *authz.MsgRevokeAll { + suite.createSendAuthorization(grantee, granter) + suite.createSendAuthorization(grantee2, granter) + return &authz.MsgRevokeAll{ + Granter: granterStrAddr, + } + }, + }, + } + + for _, tc := range testCases { + suite.Run(tc.name, func() { + _, err := suite.msgSrvr.RevokeAll(suite.ctx, tc.malleate()) + if tc.expErr { + suite.Require().Error(err) + suite.Require().Contains(err.Error(), tc.errMsg) + } else { + suite.Require().NoError(err) + totalGrants := 0 + _ = suite.authzKeeper.IterateGranterGrants(suite.ctx, granter, func(sdk.AccAddress, string) (bool, error) { + totalGrants++ + return false, nil + }) + suite.Require().Equal(0, totalGrants) + } + }) + } +}