From e7343e841a34917e74332cacc3e0e313e1387f33 Mon Sep 17 00:00:00 2001 From: David Gregory Date: Tue, 24 May 2022 12:54:18 +0100 Subject: [PATCH] Implement a concat operation on BitMapNode to merge the CHAMP structures --- .../scala-2.12/cats/data/HashMapCompat.scala | 10 +- .../scala-2.13+/cats/data/HashMapCompat.scala | 10 +- core/src/main/scala/cats/data/HashMap.scala | 156 +++++++++++++++++- 3 files changed, 156 insertions(+), 20 deletions(-) diff --git a/core/src/main/scala-2.12/cats/data/HashMapCompat.scala b/core/src/main/scala-2.12/cats/data/HashMapCompat.scala index 9acef91791..2696230865 100644 --- a/core/src/main/scala-2.12/cats/data/HashMapCompat.scala +++ b/core/src/main/scala-2.12/cats/data/HashMapCompat.scala @@ -33,7 +33,7 @@ private[data] trait HashMapCompat[K, +V] { self: HashMap[K, V] => */ final def concat[VV >: V](traversable: TraversableOnce[(K, VV)]): HashMap[K, VV] = { val newRootNode = traversable.foldLeft(self.rootNode: HashMap.Node[K, VV]) { case (node, (k, v)) => - node.updated(k, improve(self.hashKey.hash(k)), v, replaceExisting = true, 0) + node.updated(k, improve(self.hashKey.hash(k)), v, replaceExisting = true, depth = 0) } if (newRootNode eq self.rootNode) @@ -50,13 +50,9 @@ private[data] trait HashMapCompat[K, +V] { self: HashMap[K, V] => */ final def concat[VV >: V](hm: HashMap[K, VV]): HashMap[K, VV] = { val newRootNode = if (self.size <= hm.size) { - self.iterator.foldLeft(hm.rootNode) { case (node, (k, v)) => - node.updated(k, improve(self.hashKey.hash(k)), v, replaceExisting = false, 0) - } + hm.rootNode.concat(self.rootNode, replaceExisting = false, depth = 0) } else { - hm.iterator.foldLeft(self.rootNode: HashMap.Node[K, VV]) { case (node, (k, v)) => - node.updated(k, improve(self.hashKey.hash(k)), v, replaceExisting = true, 0) - } + self.rootNode.concat(hm.rootNode, replaceExisting = true, depth = 0) } if (newRootNode eq self.rootNode) diff --git a/core/src/main/scala-2.13+/cats/data/HashMapCompat.scala b/core/src/main/scala-2.13+/cats/data/HashMapCompat.scala index 854fe4beb2..ce372861d4 100644 --- a/core/src/main/scala-2.13+/cats/data/HashMapCompat.scala +++ b/core/src/main/scala-2.13+/cats/data/HashMapCompat.scala @@ -35,13 +35,15 @@ private[data] trait HashMapCompat[K, +V] extends IterableOnce[(K, V)] { self: Ha */ final def concat[VV >: V](iterable: IterableOnce[(K, VV)]): HashMap[K, VV] = { val newRootNode = iterable match { - case hm: HashMap[K, V] @unchecked if self.size <= hm.size => - self.iterator.foldLeft(hm.rootNode: HashMap.Node[K, VV]) { case (node, (k, v)) => - node.updated(k, improve(self.hashKey.hash(k)), v, replaceExisting = false, 0) + case hm: HashMap[K, V] @unchecked => + if (self.size <= hm.size) { + hm.rootNode.concat(self.rootNode, replaceExisting = false, depth = 0) + } else { + self.rootNode.concat(hm.rootNode, replaceExisting = true, depth = 0) } case _ => iterable.iterator.foldLeft(self.rootNode: HashMap.Node[K, VV]) { case (node, (k, v)) => - node.updated(k, improve(self.hashKey.hash(k)), v, replaceExisting = true, 0) + node.updated(k, improve(self.hashKey.hash(k)), v, replaceExisting = true, depth = 0) } } diff --git a/core/src/main/scala/cats/data/HashMap.scala b/core/src/main/scala/cats/data/HashMap.scala index 43e913501d..7c7a456753 100644 --- a/core/src/main/scala/cats/data/HashMap.scala +++ b/core/src/main/scala/cats/data/HashMap.scala @@ -394,7 +394,7 @@ object HashMap extends HashMapInstances with HashMapCompatCompanion { def get(key: K, keyHash: Int, depth: Int): Option[V] /** - * The current trie node updated to add the provided key-value pair. + * Return the current trie node updated to include the provided key-value pair. * * @param newKey the key to add. * @param newKeyHash the hash of the key to add. @@ -405,6 +405,16 @@ object HashMap extends HashMapInstances with HashMapCompatCompanion { */ def updated[VV >: V](newKey: K, newKeyHash: Int, value: VV, replaceExisting: Boolean, depth: Int): Node[K, VV] + /** + * Return the current trie node updated to include all of the key-value pairs of `that`. + * + * @param that the trie node to concatenate with this one. + * @param replaceExisting whether to replace existing values with those from `that` if a matching key already exists. + * @param depth the 0-indexed depth in the trie structure. + * @return a new [[HashMap.Node]] containing all elements of this trie node and `that`. + */ + def concat[VV >: V](that: Node[K, VV], replaceExisting: Boolean, depth: Int): Node[K, VV] + /** * The current trie node updated to remove the provided key. * @@ -565,6 +575,31 @@ object HashMap extends HashMapInstances with HashMapCompatCompanion { } } + final def concat[VV >: V](that: Node[K, VV], replaceExisting: Boolean, depth: Int): Node[K, VV] = + that match { + case that: CollisionNode[K, VV] @unchecked if this.collisionHash == that.collisionHash => + val builder = Vector.newBuilder[(K, VV)] + builder.sizeHint(this.size + that.size) + if (replaceExisting) { + builder ++= that.contents.toVector + this.contents.toVector.foreach { case kv @ (key, _) => + if (!that.contents.exists { case (k, _) => hashKey.eqv(key, k) }) + builder += kv + } + } else { + builder ++= this.contents.toVector + that.contents.toVector.foreach { case kv @ (key, _) => + if (!contents.exists { case (k, _) => hashKey.eqv(key, k) }) + builder += kv + } + } + new CollisionNode[K, VV](collisionHash, NonEmptyVector.fromVectorUnsafe(builder.result())) + case _: CollisionNode[_, _] => + throw new IllegalStateException("Attempting to merge collision nodes with mismatched collision hashes") + case _: BitMapNode[_, _] => + throw new IllegalStateException("Attempting to merge a collision node with a bitmap node") + } + final def ===[VV >: V](that: Node[K, VV])(implicit eqValue: Eq[VV]): Boolean = { (this eq that) || { that match { @@ -743,7 +778,7 @@ object HashMap extends HashMapInstances with HashMapCompatCompanion { rightHash: Int, rightValue: VV, depth: Int - ): Node[K, VV] = { + ): BitMapNode[K, VV] = { val newNode = mergeValues(left, leftHash, leftValue, right, rightHash, rightValue, depth) val valueIndex = Node.StrideLength * Node.indexFrom(keyValueMap, bitPos) val nodeIndex = contents.length - Node.StrideLength - Node.indexFrom(nodeMap, bitPos) @@ -765,7 +800,11 @@ object HashMap extends HashMapInstances with HashMapCompatCompanion { new BitMapNode[K, V](keyValueMap ^ bitPos, nodeMap | bitPos, newContents, size + 1) } - final private def replaceNode[VV >: V](index: Int, oldNode: Node[K, VV], newNode: Node[K, VV]): Node[K, VV] = { + final private def replaceNode[VV >: V]( + index: Int, + oldNode: Node[K, VV], + newNode: Node[K, VV] + ): BitMapNode[K, VV] = { val targetIndex = contents.length - 1 - index val newContents = new Array[Any](contents.length) System.arraycopy(contents, 0, newContents, 0, contents.length) @@ -773,6 +812,15 @@ object HashMap extends HashMapInstances with HashMapCompatCompanion { new BitMapNode[K, V](keyValueMap, nodeMap, newContents, size + (newNode.size - oldNode.size)) } + final private def appendNode[VV >: V](bitPos: Int, newNode: Node[K, VV]): BitMapNode[K, VV] = { + val newContents = new Array[Any](contents.length + 1) + val nodeIndex = newContents.length - 1 - Node.indexFrom(nodeMap, bitPos) + System.arraycopy(contents, 0, newContents, 0, nodeIndex) + newContents(nodeIndex) = newNode + System.arraycopy(contents, nodeIndex, newContents, nodeIndex + 1, newContents.length - 1 - nodeIndex) + new BitMapNode[K, V](keyValueMap, nodeMap | bitPos, newContents, size + newNode.size) + } + final private def updateNode[VV >: V]( bitPos: Int, newKey: K, @@ -780,7 +828,7 @@ object HashMap extends HashMapInstances with HashMapCompatCompanion { newValue: VV, replaceExisting: Boolean, depth: Int - ): Node[K, VV] = { + ): BitMapNode[K, VV] = { val index = Node.indexFrom(nodeMap, bitPos) val subNode = getNode(index) val newSubNode = subNode.updated(newKey, newKeyHash, newValue, replaceExisting, depth + 1) @@ -791,7 +839,7 @@ object HashMap extends HashMapInstances with HashMapCompatCompanion { replaceNode(index, subNode, newSubNode) } - final private def replaceValueAtIndex[VV >: V](index: Int, newValue: VV): Node[K, VV] = { + final private def replaceValueAtIndex[VV >: V](index: Int, newValue: VV): BitMapNode[K, VV] = { val valueIndex = Node.StrideLength * index + 1 val newContents = new Array[Any](contents.length) System.arraycopy(contents, 0, newContents, 0, contents.length) @@ -806,7 +854,7 @@ object HashMap extends HashMapInstances with HashMapCompatCompanion { newValue: VV, replaceExisting: Boolean, depth: Int - ): Node[K, VV] = { + ): BitMapNode[K, VV] = { val index = Node.indexFrom(keyValueMap, bitPos) val existingKey = getKey(index) val existingValue = getValue(index) @@ -829,7 +877,7 @@ object HashMap extends HashMapInstances with HashMapCompatCompanion { ) } - final private def appendKeyValue[VV >: V](bitPos: Int, newKey: K, newValue: VV): Node[K, VV] = { + final private def appendKeyValue[VV >: V](bitPos: Int, newKey: K, newValue: VV): BitMapNode[K, VV] = { val index = Node.StrideLength * Node.indexFrom(keyValueMap, bitPos) val newContents = new Array[Any](contents.length + Node.StrideLength) System.arraycopy(contents, 0, newContents, 0, index) @@ -858,13 +906,13 @@ object HashMap extends HashMapInstances with HashMapCompatCompanion { } } - final private def removeKeyValue(bitPos: Int, removeKey: K, removeKeyHash: Int, depth: Int): Node[K, V] = { + final private def removeKeyValue(bitPos: Int, removeKey: K, removeKeyHash: Int, depth: Int): BitMapNode[K, V] = { val index = Node.indexFrom(keyValueMap, bitPos) val existingKey = getKey(index) if (!hashKey.eqv(existingKey, removeKey)) { this } else if (allElementsCount == 1) { - Node.empty[K, V] + new BitMapNode[K, V](0, 0, Array.empty[Any], 0) } else { val keyIndex = Node.StrideLength * index val newContents = new Array[Any](contents.length - Node.StrideLength) @@ -958,6 +1006,96 @@ object HashMap extends HashMapInstances with HashMapCompatCompanion { } } + final def concat[VV >: V](that: Node[K, VV], replaceExisting: Boolean, depth: Int): Node[K, VV] = that match { + case that: BitMapNode[K, VV] @unchecked => + var newNode: BitMapNode[K, VV] = this + + var index = 0 + var thisNodeMap = newNode.nodeMap + var thatNodeMap = that.nodeMap + val maxNewNodeIndex = Integer.numberOfTrailingZeros(Integer.highestOneBit(thatNodeMap)) + while (index <= maxNewNodeIndex) { + val thisHasNode = (thisNodeMap & 1) == 1 + val thatHasNode = (thatNodeMap & 1) == 1 + if (thisHasNode && thatHasNode) { + // Merge `this` and `that` nodes + val bitPos = Node.bitPosFrom(index) + val thisIndex = Node.indexFrom(newNode.nodeMap, bitPos) + val thisNode = newNode.getNode(thisIndex) + val thatNode = that.getNode(Node.indexFrom(that.nodeMap, bitPos)) + val mergedNode = thisNode.concat(thatNode, replaceExisting, depth + 1) + newNode = newNode.replaceNode(thisIndex, thisNode, mergedNode) + } else if (thatHasNode) { + // Copy node from `that` + val bitPos = Node.bitPosFrom(index) + val thatNodeIndex = Node.indexFrom(that.nodeMap, bitPos) + val thatNode = that.getNode(thatNodeIndex) + newNode = newNode.appendNode(bitPos, thatNode) + } + + thisNodeMap >>= 1 + thatNodeMap >>= 1 + index += 1 + } + + index = 0 + var thisKeyValueMap = this.keyValueMap + var thatKeyValueMap = that.keyValueMap + val bothKeyValueMap = thisKeyValueMap | thatKeyValueMap + val maxKeyValueIndex = Integer.numberOfTrailingZeros(Integer.highestOneBit(bothKeyValueMap)) + while (index <= maxKeyValueIndex) { + val thisHasKeyValue = (thisKeyValueMap & 1) == 1 + val thatHasKeyValue = (thatKeyValueMap & 1) == 1 + if (thisHasKeyValue && thatHasKeyValue) { + // Merge `this` and `that` key value pair + val bitPos = Node.bitPosFrom(index) + val thatIndex = Node.indexFrom(that.keyValueMap, bitPos) + val thatKey = that.getKey(thatIndex) + val thatValue = that.getValue(thatIndex) + newNode = newNode.updateKeyValue( + bitPos, + thatKey, + improve(hashKey.hash(thatKey)), + thatValue, + replaceExisting, + depth + ) + } else if (thisHasKeyValue) { + val bitPos = Node.bitPosFrom(index) + if (newNode.hasNodeAt(bitPos)) { + // Move `this` key-value pair into `that`'s node for this hash + val valueIndex = Node.indexFrom(this.keyValueMap, bitPos) + val thisKey = this.getKey(valueIndex) + val thisKeyHash = improve(hashKey.hash(thisKey)) + val thisValue = this.getValue(valueIndex) + newNode = newNode.removeKeyValue(bitPos, thisKey, thisKeyHash, depth) + newNode = newNode.updateNode(bitPos, thisKey, thisKeyHash, thisValue, !replaceExisting, depth) + } + } else if (thatHasKeyValue) { + // Move `that` key value pair into `this` + val bitPos = Node.bitPosFrom(index) + val thatIndex = Node.indexFrom(that.keyValueMap, bitPos) + val thatKey = that.getKey(thatIndex) + val thatKeyHash = improve(hashKey.hash(thatKey)) + val thatValue = that.getValue(thatIndex) + if (newNode.hasNodeAt(bitPos)) { + newNode = newNode.updateNode(bitPos, thatKey, thatKeyHash, thatValue, replaceExisting, depth) + } else { + newNode = newNode.appendKeyValue(bitPos, thatKey, thatValue) + } + } + + thisKeyValueMap >>= 1 + thatKeyValueMap >>= 1 + index += 1 + } + + newNode + + case _: CollisionNode[_, _] => + throw new IllegalStateException("Attempting to merge a bitmap node with a collision node") + } + final override def ===[VV >: V](that: Node[K, VV])(implicit eqValue: Eq[VV]): Boolean = { (this eq that) || { that match {