diff --git a/contracts/mocks/EnumerableMapMock.sol b/contracts/mocks/EnumerableMapMock.sol index 79226d739ec..28472b91787 100644 --- a/contracts/mocks/EnumerableMapMock.sol +++ b/contracts/mocks/EnumerableMapMock.sol @@ -34,7 +34,15 @@ contract EnumerableMapMock { } + function tryGet(uint256 key) public view returns (bool, address) { + return _map.tryGet(key); + } + function get(uint256 key) public view returns (address) { return _map.get(key); } + + function getWithMessage(uint256 key, string calldata errorMessage) public view returns (address) { + return _map.get(key, errorMessage); + } } diff --git a/contracts/utils/EnumerableMap.sol b/contracts/utils/EnumerableMap.sol index 2daa131905a..3eca02d0a7f 100644 --- a/contracts/utils/EnumerableMap.sol +++ b/contracts/utils/EnumerableMap.sol @@ -143,6 +143,16 @@ library EnumerableMap { return (entry._key, entry._value); } + /** + * @dev Tries to returns the value associated with `key`. O(1). + * Does not revert if `key` is not in the map. + */ + function _tryGet(Map storage map, bytes32 key) private view returns (bool, bytes32) { + uint256 keyIndex = map._indexes[key]; + if (keyIndex == 0) return (false, 0); // Equivalent to contains(map, key) + return (true, map._entries[keyIndex - 1]._value); // All indexes are 1-based + } + /** * @dev Returns the value associated with `key`. O(1). * @@ -151,7 +161,9 @@ library EnumerableMap { * - `key` must be in the map. */ function _get(Map storage map, bytes32 key) private view returns (bytes32) { - return _get(map, key, "EnumerableMap: nonexistent key"); + uint256 keyIndex = map._indexes[key]; + require(keyIndex != 0, "EnumerableMap: nonexistent key"); // Equivalent to contains(map, key) + return map._entries[keyIndex - 1]._value; // All indexes are 1-based } /** @@ -217,6 +229,15 @@ library EnumerableMap { return (uint256(key), address(uint160(uint256(value)))); } + /** + * @dev Tries to returns the value associated with `key`. O(1). + * Does not revert if `key` is not in the map. + */ + function tryGet(UintToAddressMap storage map, uint256 key) internal view returns (bool, address) { + (bool success, bytes32 value) = _tryGet(map._inner, bytes32(key)); + return (success, address(uint160(uint256(value)))); + } + /** * @dev Returns the value associated with `key`. O(1). * diff --git a/test/utils/EnumerableMap.test.js b/test/utils/EnumerableMap.test.js index fcbdba6d7fd..29052d0f3a7 100644 --- a/test/utils/EnumerableMap.test.js +++ b/test/utils/EnumerableMap.test.js @@ -1,4 +1,4 @@ -const { BN, expectEvent } = require('@openzeppelin/test-helpers'); +const { BN, constants, expectEvent, expectRevert } = require('@openzeppelin/test-helpers'); const { expect } = require('chai'); const zip = require('lodash.zip'); @@ -139,4 +139,44 @@ contract('EnumerableMap', function (accounts) { expect(await this.map.contains(keyB)).to.equal(false); }); }); + + describe('read', function () { + beforeEach(async function () { + await this.map.set(keyA, accountA); + }); + + describe('get', function () { + it('existing value', async function () { + expect(await this.map.get(keyA)).to.be.equal(accountA); + }); + it('missing value', async function () { + await expectRevert(this.map.get(keyB), "EnumerableMap: nonexistent key"); + }); + }); + + describe('get with message', function () { + it('existing value', async function () { + expect(await this.map.getWithMessage(keyA, "custom error string")).to.be.equal(accountA); + }); + it('missing value', async function () { + await expectRevert(this.map.getWithMessage(keyB, "custom error string"), "custom error string"); + }); + }); + + describe('tryGet', function () { + it('existing value', async function () { + expect(await this.map.tryGet(keyA)).to.be.deep.equal({ + '0': true, + '1': accountA + }); + }); + it('missing value', async function () { + expect(await this.map.tryGet(keyB)).to.be.deep.equal({ + '0': false, + '1': constants.ZERO_ADDRESS + }); + }); + }); + + }); });