From 6cecbc82e84db49516dca5d6ec41441588c2303d Mon Sep 17 00:00:00 2001 From: Jannis Date: Tue, 6 Apr 2021 11:19:54 +0200 Subject: [PATCH] Traverse fixes (#2365) * FoldRight order * Test and fix traverse implementations * Undo foldRight changes Co-authored-by: Simon Vergauwen --- .../src/main/kotlin/arrow/core/Iterable.kt | 21 ++++---- .../main/kotlin/arrow/core/NonEmptyList.kt | 28 +++++----- .../src/main/kotlin/arrow/core/Sequence.kt | 34 ++++++++---- .../src/main/kotlin/arrow/core/map.kt | 32 +++++++----- .../test/kotlin/arrow/core/IterableTest.kt | 34 +++++++++--- .../src/test/kotlin/arrow/core/MapKTest.kt | 51 ++++++++++++++++++ .../kotlin/arrow/core/NonEmptyListTest.kt | 52 +++++++++++++++++++ .../test/kotlin/arrow/core/SequenceKTest.kt | 42 +++++++++++++-- 8 files changed, 238 insertions(+), 56 deletions(-) diff --git a/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Iterable.kt b/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Iterable.kt index 3804f9ee1df..8da56d0f1a1 100644 --- a/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Iterable.kt +++ b/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Iterable.kt @@ -287,13 +287,16 @@ inline fun Iterable.foldRight(initial: B, operation: (A, acc: B) -> B) else -> reversed().fold(initial) { acc, a -> operation(a, acc) } } -inline fun Iterable.traverseEither(f: (A) -> Either): Either> = - foldRight>>(emptyList().right()) { a, acc -> +inline fun Iterable.traverseEither(f: (A) -> Either): Either> { + val acc = mutableListOf() + forEach { a -> when (val res = f(a)) { - is Right -> acc.map { bs -> listOf(res.value) + bs } - is Left -> res + is Right -> acc.add(res.value) + is Left -> return@traverseEither res } } + return acc.right() +} fun Iterable>.sequenceEither(): Either> = traverseEither(::identity) @@ -302,15 +305,15 @@ inline fun Iterable.traverseValidated( semigroup: Semigroup, f: (A) -> Validated ): Validated> = semigroup.run { - foldRight>>(emptyList().valid()) { a, acc -> + fold(Valid(mutableListOf()) as Validated>) { acc, a -> when (val res = f(a)) { is Validated.Valid -> when (acc) { - is Validated.Valid -> acc.map { bs -> listOf(res.value) + bs } - is Validated.Invalid -> acc + is Valid -> acc.also { it.value.add(res.value) } + is Invalid -> acc } is Validated.Invalid -> when (acc) { - is Validated.Valid -> res - is Validated.Invalid -> Invalid(res.value.combine(acc.value)) + is Valid -> res + is Invalid -> acc.value.combine(res.value).invalid() } } } diff --git a/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/NonEmptyList.kt b/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/NonEmptyList.kt index 2a039a3502b..e54e14853e8 100644 --- a/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/NonEmptyList.kt +++ b/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/NonEmptyList.kt @@ -412,13 +412,17 @@ fun NonEmptyList.unzip(f: (C) -> Pair): Pair, } } -inline fun NonEmptyList.traverseEither(f: (A) -> Either): Either> = - foldRight(f(head).map(::nonEmptyListOf)) { a, acc -> +inline fun NonEmptyList.traverseEither(f: (A) -> Either): Either> { + val acc = mutableListOf() + forEach { a -> when (val res = f(a)) { - is Right -> acc.map { bs -> nonEmptyListOf(res.value) + bs } - is Left -> res + is Right -> acc.add(res.value) + is Left -> return@traverseEither res } } + // Safe due to traverse laws + return NonEmptyList.fromListUnsafe(acc).right() +} fun NonEmptyList>.sequenceEither(): Either> = traverseEither(::identity) @@ -427,18 +431,18 @@ inline fun NonEmptyList.traverseValidated( semigroup: Semigroup, f: (A) -> Validated ): Validated> = - foldRight(f(head).map(::nonEmptyListOf)) { a, acc -> + fold(mutableListOf().valid() as Validated>) { acc, a -> when (val res = f(a)) { - is Validated.Valid -> when (acc) { - is Validated.Valid -> acc.map { bs -> nonEmptyListOf(res.value) + bs } - is Validated.Invalid -> acc + is Valid -> when (acc) { + is Valid -> acc.also { it.value.add(res.value) } + is Invalid -> acc } - is Validated.Invalid -> when (acc) { - is Validated.Valid -> res - is Validated.Invalid -> Invalid(semigroup.run { res.value.combine(acc.value) }) + is Invalid -> when (acc) { + is Valid -> res + is Invalid -> semigroup.run { acc.value.combine(res.value).invalid() } } } - } + }.map { NonEmptyList.fromListUnsafe(it) } fun NonEmptyList>.sequenceValidated(semigroup: Semigroup): Validated> = traverseValidated(semigroup, ::identity) diff --git a/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Sequence.kt b/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Sequence.kt index 8d4e6ee8468..4a8eae12f60 100644 --- a/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Sequence.kt +++ b/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Sequence.kt @@ -643,23 +643,35 @@ fun Sequence.split(): Pair, A>? = fun Sequence.tail(): Sequence = drop(1) -fun Sequence.traverseEither(f: (A) -> Either): Either> = - foldRight>>(Eval.now(sequenceOf().right())) { a, eval -> +fun Sequence.traverseEither(f: (A) -> Either): Either> { + // Note: Using a mutable list here avoids the stackoverflows one can accidentally create when using + // Sequence.plus instead. But we don't convert the sequence to a list beforehand to avoid + // forcing too much of the sequence to be evaluated. + val acc = mutableListOf() + forEach { a -> when (val res = f(a)) { - is Right -> eval.map { either -> - either.map { bs -> sequenceOf(res.value) + bs } - } - is Left -> Eval.now(res.value.left()) + is Right -> acc.add(res.value) + is Left -> return@traverseEither res } - }.value() + } + return acc.asSequence().right() +} fun Sequence.traverseValidated( semigroup: Semigroup, f: (A) -> Validated -): Validated> = - foldRight>>(Eval.now(emptySequence().valid())) { a, acc -> - acc.map { f(a).zip(semigroup, it) { b, bs -> sequenceOf(b) + bs } } - }.value() +): Validated> = fold(mutableListOf().valid() as Validated>) { acc, a -> + when (val res = f(a)) { + is Valid -> when (acc) { + is Valid -> acc.also { it.value.add(res.value) } + is Invalid -> acc + } + is Invalid -> when (acc) { + is Valid -> res + is Invalid -> semigroup.run { acc.value.combine(res.value).invalid() } + } + } +}.map { it.asSequence() } /** * splits an union into its component parts. diff --git a/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/map.kt b/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/map.kt index aaed4901c3a..20f767571dc 100644 --- a/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/map.kt +++ b/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/map.kt @@ -189,13 +189,16 @@ fun Map.flatMap(f: (Map.Entry) -> Map): Map = f(entry)[entry.key]?.let { Pair(entry.key, it) }.asIterable() }.toMap() -inline fun Map.traverseEither(f: (A) -> Either): Either> = - foldRight(emptyMap().right()) { (k, a), acc: Either> -> - when (val res = f(a)) { - is Right -> acc.map { bs: Map -> mapOf(k to res.value) + bs } - is Left -> res +inline fun Map.traverseEither(f: (A) -> Either): Either> { + val acc = mutableMapOf() + forEach { (k, v) -> + when (val res = f(v)) { + is Right -> acc[k] = res.value + is Left -> return@traverseEither res } } + return acc.right() +} fun Map>.sequenceEither(): Either> = traverseEither(::identity) @@ -203,19 +206,20 @@ fun Map>.sequenceEither(): Either> = inline fun Map.traverseValidated( semigroup: Semigroup, f: (A) -> Validated -): Validated> = - foldRight>>(emptyMap().valid()) { (k, a), acc -> - when (val res = f(a)) { - is Validated.Valid -> when (acc) { - is Validated.Valid -> acc.map { bs -> mapOf(k to res.value) + bs } - is Validated.Invalid -> acc +): Validated> { + return foldLeft(mutableMapOf().valid() as Validated>) { acc, (k, v) -> + when (val res = f(v)) { + is Valid -> when (acc) { + is Valid -> acc.also { it.value[k] = res.value } + is Invalid -> acc } - is Validated.Invalid -> when (acc) { - is Validated.Valid -> res - is Validated.Invalid -> Invalid(semigroup.run { res.value.combine(acc.value) }) + is Invalid -> when (acc) { + is Valid -> res + is Invalid -> semigroup.run { acc.value.combine(res.value).invalid() } } } } +} fun Map>.sequenceValidated(semigroup: Semigroup): Validated> = traverseValidated(semigroup, ::identity) diff --git a/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/IterableTest.kt b/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/IterableTest.kt index 5ea5204c00f..5b437b1f96c 100644 --- a/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/IterableTest.kt +++ b/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/IterableTest.kt @@ -12,20 +12,42 @@ import kotlin.math.min class IterableTest : UnitSpec() { init { "traverseEither stack-safe" { - (0..20_000).map { Either.Right(it) } - .sequenceEither() shouldBe Either.Right((0..20_000).toList()) + // also verifies result order and execution order (l to r) + val acc = mutableListOf() + val res = (0..20_000).traverseEither { a -> + acc.add(a) + Either.Right(a) + } + res shouldBe Either.Right(acc) + res shouldBe Either.Right((0..20_000).toList()) } "traverseEither short-circuit" { forAll(Gen.list(Gen.int())) { ints -> - (ints.map { Either.Right(it) } + Either.Left(Unit)) - .sequenceEither() == Either.Left(Unit) + val acc = mutableListOf() + val evens = ints.traverseEither { + if (it % 2 == 0) { + acc.add(it) + Either.Right(it) + } else Either.Left(it) + } + acc == ints.takeWhile { it % 2 == 0 } && + when (evens) { + is Either.Right -> evens.value == ints + is Either.Left -> evens.value == ints.first { it % 2 != 0 } + } } } "traverseValidated stack-safe" { - (0..20_000).map { Validated.Valid(it) } - .sequenceValidated(Semigroup.string()) shouldBe Validated.Valid((0..20_000).toList()) + // also verifies result order and execution order (l to r) + val acc = mutableListOf() + val res = (0..20_000).traverseValidated(Semigroup.string()) { + acc.add(it) + Validated.Valid(it) + } + res shouldBe Validated.Valid(acc) + res shouldBe Validated.Valid((0..20_000).toList()) } "traverseValidated acummulates" { diff --git a/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/MapKTest.kt b/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/MapKTest.kt index b4e9fc92015..12d19dd9a1a 100644 --- a/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/MapKTest.kt +++ b/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/MapKTest.kt @@ -8,12 +8,63 @@ import arrow.typeclasses.Monoid import arrow.typeclasses.Semigroup import io.kotlintest.properties.Gen import io.kotlintest.properties.forAll +import io.kotlintest.shouldBe class MapKTest : UnitSpec() { init { testLaws(MonoidLaws.laws(Monoid.map(Semigroup.int()), Gen.map(Gen.longSmall(), Gen.intSmall()))) + "traverseEither is stacksafe" { + val acc = mutableListOf() + val res = (0..20_000).map { it to it }.toMap().traverseEither { v -> + acc.add(v) + Either.Right(v) + } + res shouldBe acc.map { it to it }.toMap().right() + res shouldBe (0..20_000).map { it to it }.toMap().right() + } + + "traverseEither short-circuit" { + forAll(Gen.map(Gen.int(), Gen.int())) { ints -> + val acc = mutableListOf() + val evens = ints.traverseEither { + if (it % 2 == 0) { + acc.add(it) + Either.Right(it) + } else Either.Left(it) + } + acc == ints.values.takeWhile { it % 2 == 0 } && + when (evens) { + is Either.Right -> evens.value == ints + is Either.Left -> evens.value == ints.values.first { it % 2 != 0 } + } + } + } + + "traverseValidated is stacksafe" { + val acc = mutableListOf() + val res = (0..20_000).map { it to it }.toMap().traverseValidated(Semigroup.string()) { v -> + acc.add(v) + Validated.Valid(v) + } + res shouldBe acc.map { it to it }.toMap().valid() + res shouldBe (0..20_000).map { it to it }.toMap().valid() + } + + "traverseValidated acummulates" { + forAll(Gen.map(Gen.int(), Gen.int())) { ints -> + val res: ValidatedNel> = + ints.traverseValidated(Semigroup.nonEmptyList()) { i -> if (i % 2 == 0) i.validNel() else i.invalidNel() } + + val expected: ValidatedNel> = + NonEmptyList.fromList(ints.values.filterNot { it % 2 == 0 }) + .fold({ ints.entries.filter { (_, v) -> v % 2 == 0 }.map { (k, v) -> k to v }.toMap().validNel() }, { it.invalid() }) + + res == expected + } + } + "can align maps" { // aligned keySet is union of a's and b's keys forAll(Gen.map(Gen.long(), Gen.bool()), Gen.map(Gen.long(), Gen.bool())) { a, b -> diff --git a/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/NonEmptyListTest.kt b/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/NonEmptyListTest.kt index 94da056547d..9e78a8eec5e 100644 --- a/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/NonEmptyListTest.kt +++ b/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/NonEmptyListTest.kt @@ -6,6 +6,7 @@ import arrow.core.test.laws.SemigroupLaws import arrow.typeclasses.Semigroup import io.kotlintest.properties.Gen import io.kotlintest.properties.forAll +import io.kotlintest.shouldBe import kotlin.math.max import kotlin.math.min @@ -14,6 +15,57 @@ class NonEmptyListTest : UnitSpec() { testLaws(SemigroupLaws.laws(Semigroup.nonEmptyList(), Gen.nonEmptyList(Gen.int()))) + "traverseEither stack-safe" { + // also verifies result order and execution order (l to r) + val acc = mutableListOf() + val res = NonEmptyList.fromListUnsafe((0..20_000).toList()).traverseEither { a -> + acc.add(a) + Either.Right(a) + } + res shouldBe Either.Right(NonEmptyList.fromListUnsafe(acc)) + res shouldBe Either.Right(NonEmptyList.fromListUnsafe((0..20_000).toList())) + } + + "traverseEither short-circuit" { + forAll(Gen.nonEmptyList(Gen.int())) { ints -> + val acc = mutableListOf() + val evens = ints.traverseEither { + if (it % 2 == 0) { + acc.add(it) + Either.Right(it) + } else Either.Left(it) + } + acc == ints.takeWhile { it % 2 == 0 } && + when (evens) { + is Either.Right -> evens.value == ints + is Either.Left -> evens.value == ints.first { it % 2 != 0 } + } + } + } + + "traverseValidated stack-safe" { + // also verifies result order and execution order (l to r) + val acc = mutableListOf() + val res = (0..20_000).traverseValidated(Semigroup.string()) { + acc.add(it) + Validated.Valid(it) + } + res shouldBe Validated.Valid(acc) + res shouldBe Validated.Valid((0..20_000).toList()) + } + + "traverseValidated acummulates" { + forAll(Gen.nonEmptyList(Gen.int())) { ints -> + val res: ValidatedNel> = + ints.traverseValidated(Semigroup.nonEmptyList()) { i -> if (i % 2 == 0) i.validNel() else i.invalidNel() } + + val expected: ValidatedNel> = NonEmptyList.fromList(ints.filterNot { it % 2 == 0 }) + .fold({ NonEmptyList.fromListUnsafe(ints.filter { it % 2 == 0 }).validNel() }, { it.invalid() }) + + res == expected + } + } + "can align lists with different lengths" { forAll(Gen.nonEmptyList(Gen.bool()), Gen.nonEmptyList(Gen.bool())) { a, b -> a.align(b).size == max(a.size, b.size) diff --git a/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/SequenceKTest.kt b/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/SequenceKTest.kt index 04f33b86445..25045568c9a 100644 --- a/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/SequenceKTest.kt +++ b/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/SequenceKTest.kt @@ -4,6 +4,7 @@ import arrow.core.test.UnitSpec import arrow.core.test.generators.sequence import arrow.core.test.laws.MonoidLaws import arrow.typeclasses.Monoid +import arrow.typeclasses.Semigroup import io.kotlintest.matchers.sequences.shouldBeEmpty import io.kotlintest.properties.Gen import io.kotlintest.properties.forAll @@ -17,10 +18,43 @@ class SequenceKTest : UnitSpec() { testLaws(MonoidLaws.laws(Monoid.sequence(), Gen.sequence(Gen.int())) { s1, s2 -> s1.toList() == s2.toList() }) - "traverseEither is stacksafe over very long collections and short circuits properly" { - // This has to traverse 30k elements till it reaches None and terminates - generateSequence(0) { it + 1 }.map { if (it < 20_000) Either.Right(it) else Either.Left(Unit) } - .sequenceEither() shouldBe Either.Left(Unit) + "traverseEither stack-safe" { + // also verifies result order and execution order (l to r) + val acc = mutableListOf() + val res = generateSequence(0) { it + 1 }.traverseEither { a -> + if (a > 20_000) { + Either.Left(Unit) + } else { + acc.add(a) + Either.Right(a) + } + } + acc shouldBe (0..20_000).toList() + res shouldBe Either.Left(Unit) + } + + "traverseValidated stack-safe" { + // also verifies result order and execution order (l to r) + val acc = mutableListOf() + val res = (0..20_000).asSequence().traverseValidated(Semigroup.string()) { + acc.add(it) + Validated.Valid(it) + }.map { it.toList() } + res shouldBe Validated.Valid(acc) + res shouldBe Validated.Valid((0..20_000).toList()) + } + + "traverseValidated acummulates" { + forAll(Gen.list(Gen.int())) { ints -> + val ints = ints.asSequence() + val res: ValidatedNel> = ints.map { i -> if (i % 2 == 0) i.validNel() else i.invalidNel() } + .sequenceValidated(Semigroup.nonEmptyList()) + + val expected: ValidatedNel> = NonEmptyList.fromList(ints.filterNot { it % 2 == 0 }.toList()) + .fold({ ints.filter { it % 2 == 0 }.validNel() }, { it.invalid() }) + + res.map { it.toList() } == expected.map { it.toList() } + } } "zip3" {