diff --git a/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Iterable.kt b/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Iterable.kt
index 3804f9ee1df..8da56d0f1a1 100644
--- a/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Iterable.kt
+++ b/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Iterable.kt
@@ -287,13 +287,16 @@ inline fun Iterable.foldRight(initial: B, operation: (A, acc: B) -> B)
else -> reversed().fold(initial) { acc, a -> operation(a, acc) }
}
-inline fun Iterable.traverseEither(f: (A) -> Either): Either> =
- foldRight>>(emptyList().right()) { a, acc ->
+inline fun Iterable.traverseEither(f: (A) -> Either): Either> {
+ val acc = mutableListOf()
+ forEach { a ->
when (val res = f(a)) {
- is Right -> acc.map { bs -> listOf(res.value) + bs }
- is Left -> res
+ is Right -> acc.add(res.value)
+ is Left -> return@traverseEither res
}
}
+ return acc.right()
+}
fun Iterable>.sequenceEither(): Either> =
traverseEither(::identity)
@@ -302,15 +305,15 @@ inline fun Iterable.traverseValidated(
semigroup: Semigroup,
f: (A) -> Validated
): Validated> = semigroup.run {
- foldRight>>(emptyList().valid()) { a, acc ->
+ fold(Valid(mutableListOf()) as Validated>) { acc, a ->
when (val res = f(a)) {
is Validated.Valid -> when (acc) {
- is Validated.Valid -> acc.map { bs -> listOf(res.value) + bs }
- is Validated.Invalid -> acc
+ is Valid -> acc.also { it.value.add(res.value) }
+ is Invalid -> acc
}
is Validated.Invalid -> when (acc) {
- is Validated.Valid -> res
- is Validated.Invalid -> Invalid(res.value.combine(acc.value))
+ is Valid -> res
+ is Invalid -> acc.value.combine(res.value).invalid()
}
}
}
diff --git a/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/NonEmptyList.kt b/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/NonEmptyList.kt
index 2a039a3502b..e54e14853e8 100644
--- a/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/NonEmptyList.kt
+++ b/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/NonEmptyList.kt
@@ -412,13 +412,17 @@ fun NonEmptyList.unzip(f: (C) -> Pair): Pair,
}
}
-inline fun NonEmptyList.traverseEither(f: (A) -> Either): Either> =
- foldRight(f(head).map(::nonEmptyListOf)) { a, acc ->
+inline fun NonEmptyList.traverseEither(f: (A) -> Either): Either> {
+ val acc = mutableListOf()
+ forEach { a ->
when (val res = f(a)) {
- is Right -> acc.map { bs -> nonEmptyListOf(res.value) + bs }
- is Left -> res
+ is Right -> acc.add(res.value)
+ is Left -> return@traverseEither res
}
}
+ // Safe due to traverse laws
+ return NonEmptyList.fromListUnsafe(acc).right()
+}
fun NonEmptyList>.sequenceEither(): Either> =
traverseEither(::identity)
@@ -427,18 +431,18 @@ inline fun NonEmptyList.traverseValidated(
semigroup: Semigroup,
f: (A) -> Validated
): Validated> =
- foldRight(f(head).map(::nonEmptyListOf)) { a, acc ->
+ fold(mutableListOf().valid() as Validated>) { acc, a ->
when (val res = f(a)) {
- is Validated.Valid -> when (acc) {
- is Validated.Valid -> acc.map { bs -> nonEmptyListOf(res.value) + bs }
- is Validated.Invalid -> acc
+ is Valid -> when (acc) {
+ is Valid -> acc.also { it.value.add(res.value) }
+ is Invalid -> acc
}
- is Validated.Invalid -> when (acc) {
- is Validated.Valid -> res
- is Validated.Invalid -> Invalid(semigroup.run { res.value.combine(acc.value) })
+ is Invalid -> when (acc) {
+ is Valid -> res
+ is Invalid -> semigroup.run { acc.value.combine(res.value).invalid() }
}
}
- }
+ }.map { NonEmptyList.fromListUnsafe(it) }
fun NonEmptyList>.sequenceValidated(semigroup: Semigroup): Validated> =
traverseValidated(semigroup, ::identity)
diff --git a/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Sequence.kt b/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Sequence.kt
index 8d4e6ee8468..4a8eae12f60 100644
--- a/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Sequence.kt
+++ b/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Sequence.kt
@@ -643,23 +643,35 @@ fun Sequence.split(): Pair, A>? =
fun Sequence.tail(): Sequence =
drop(1)
-fun Sequence.traverseEither(f: (A) -> Either): Either> =
- foldRight>>(Eval.now(sequenceOf().right())) { a, eval ->
+fun Sequence.traverseEither(f: (A) -> Either): Either> {
+ // Note: Using a mutable list here avoids the stackoverflows one can accidentally create when using
+ // Sequence.plus instead. But we don't convert the sequence to a list beforehand to avoid
+ // forcing too much of the sequence to be evaluated.
+ val acc = mutableListOf()
+ forEach { a ->
when (val res = f(a)) {
- is Right -> eval.map { either ->
- either.map { bs -> sequenceOf(res.value) + bs }
- }
- is Left -> Eval.now(res.value.left())
+ is Right -> acc.add(res.value)
+ is Left -> return@traverseEither res
}
- }.value()
+ }
+ return acc.asSequence().right()
+}
fun Sequence.traverseValidated(
semigroup: Semigroup,
f: (A) -> Validated
-): Validated> =
- foldRight>>(Eval.now(emptySequence().valid())) { a, acc ->
- acc.map { f(a).zip(semigroup, it) { b, bs -> sequenceOf(b) + bs } }
- }.value()
+): Validated> = fold(mutableListOf().valid() as Validated>) { acc, a ->
+ when (val res = f(a)) {
+ is Valid -> when (acc) {
+ is Valid -> acc.also { it.value.add(res.value) }
+ is Invalid -> acc
+ }
+ is Invalid -> when (acc) {
+ is Valid -> res
+ is Invalid -> semigroup.run { acc.value.combine(res.value).invalid() }
+ }
+ }
+}.map { it.asSequence() }
/**
* splits an union into its component parts.
diff --git a/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/map.kt b/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/map.kt
index aaed4901c3a..20f767571dc 100644
--- a/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/map.kt
+++ b/arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/map.kt
@@ -189,13 +189,16 @@ fun Map.flatMap(f: (Map.Entry) -> Map): Map =
f(entry)[entry.key]?.let { Pair(entry.key, it) }.asIterable()
}.toMap()
-inline fun Map.traverseEither(f: (A) -> Either): Either> =
- foldRight(emptyMap().right()) { (k, a), acc: Either> ->
- when (val res = f(a)) {
- is Right -> acc.map { bs: Map -> mapOf(k to res.value) + bs }
- is Left -> res
+inline fun Map.traverseEither(f: (A) -> Either): Either> {
+ val acc = mutableMapOf()
+ forEach { (k, v) ->
+ when (val res = f(v)) {
+ is Right -> acc[k] = res.value
+ is Left -> return@traverseEither res
}
}
+ return acc.right()
+}
fun Map>.sequenceEither(): Either> =
traverseEither(::identity)
@@ -203,19 +206,20 @@ fun Map>.sequenceEither(): Either> =
inline fun Map.traverseValidated(
semigroup: Semigroup,
f: (A) -> Validated
-): Validated> =
- foldRight>>(emptyMap().valid()) { (k, a), acc ->
- when (val res = f(a)) {
- is Validated.Valid -> when (acc) {
- is Validated.Valid -> acc.map { bs -> mapOf(k to res.value) + bs }
- is Validated.Invalid -> acc
+): Validated> {
+ return foldLeft(mutableMapOf().valid() as Validated>) { acc, (k, v) ->
+ when (val res = f(v)) {
+ is Valid -> when (acc) {
+ is Valid -> acc.also { it.value[k] = res.value }
+ is Invalid -> acc
}
- is Validated.Invalid -> when (acc) {
- is Validated.Valid -> res
- is Validated.Invalid -> Invalid(semigroup.run { res.value.combine(acc.value) })
+ is Invalid -> when (acc) {
+ is Valid -> res
+ is Invalid -> semigroup.run { acc.value.combine(res.value).invalid() }
}
}
}
+}
fun Map>.sequenceValidated(semigroup: Semigroup): Validated> =
traverseValidated(semigroup, ::identity)
diff --git a/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/IterableTest.kt b/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/IterableTest.kt
index 5ea5204c00f..5b437b1f96c 100644
--- a/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/IterableTest.kt
+++ b/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/IterableTest.kt
@@ -12,20 +12,42 @@ import kotlin.math.min
class IterableTest : UnitSpec() {
init {
"traverseEither stack-safe" {
- (0..20_000).map { Either.Right(it) }
- .sequenceEither() shouldBe Either.Right((0..20_000).toList())
+ // also verifies result order and execution order (l to r)
+ val acc = mutableListOf()
+ val res = (0..20_000).traverseEither { a ->
+ acc.add(a)
+ Either.Right(a)
+ }
+ res shouldBe Either.Right(acc)
+ res shouldBe Either.Right((0..20_000).toList())
}
"traverseEither short-circuit" {
forAll(Gen.list(Gen.int())) { ints ->
- (ints.map { Either.Right(it) } + Either.Left(Unit))
- .sequenceEither() == Either.Left(Unit)
+ val acc = mutableListOf()
+ val evens = ints.traverseEither {
+ if (it % 2 == 0) {
+ acc.add(it)
+ Either.Right(it)
+ } else Either.Left(it)
+ }
+ acc == ints.takeWhile { it % 2 == 0 } &&
+ when (evens) {
+ is Either.Right -> evens.value == ints
+ is Either.Left -> evens.value == ints.first { it % 2 != 0 }
+ }
}
}
"traverseValidated stack-safe" {
- (0..20_000).map { Validated.Valid(it) }
- .sequenceValidated(Semigroup.string()) shouldBe Validated.Valid((0..20_000).toList())
+ // also verifies result order and execution order (l to r)
+ val acc = mutableListOf()
+ val res = (0..20_000).traverseValidated(Semigroup.string()) {
+ acc.add(it)
+ Validated.Valid(it)
+ }
+ res shouldBe Validated.Valid(acc)
+ res shouldBe Validated.Valid((0..20_000).toList())
}
"traverseValidated acummulates" {
diff --git a/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/MapKTest.kt b/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/MapKTest.kt
index b4e9fc92015..12d19dd9a1a 100644
--- a/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/MapKTest.kt
+++ b/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/MapKTest.kt
@@ -8,12 +8,63 @@ import arrow.typeclasses.Monoid
import arrow.typeclasses.Semigroup
import io.kotlintest.properties.Gen
import io.kotlintest.properties.forAll
+import io.kotlintest.shouldBe
class MapKTest : UnitSpec() {
init {
testLaws(MonoidLaws.laws(Monoid.map(Semigroup.int()), Gen.map(Gen.longSmall(), Gen.intSmall())))
+ "traverseEither is stacksafe" {
+ val acc = mutableListOf()
+ val res = (0..20_000).map { it to it }.toMap().traverseEither { v ->
+ acc.add(v)
+ Either.Right(v)
+ }
+ res shouldBe acc.map { it to it }.toMap().right()
+ res shouldBe (0..20_000).map { it to it }.toMap().right()
+ }
+
+ "traverseEither short-circuit" {
+ forAll(Gen.map(Gen.int(), Gen.int())) { ints ->
+ val acc = mutableListOf()
+ val evens = ints.traverseEither {
+ if (it % 2 == 0) {
+ acc.add(it)
+ Either.Right(it)
+ } else Either.Left(it)
+ }
+ acc == ints.values.takeWhile { it % 2 == 0 } &&
+ when (evens) {
+ is Either.Right -> evens.value == ints
+ is Either.Left -> evens.value == ints.values.first { it % 2 != 0 }
+ }
+ }
+ }
+
+ "traverseValidated is stacksafe" {
+ val acc = mutableListOf()
+ val res = (0..20_000).map { it to it }.toMap().traverseValidated(Semigroup.string()) { v ->
+ acc.add(v)
+ Validated.Valid(v)
+ }
+ res shouldBe acc.map { it to it }.toMap().valid()
+ res shouldBe (0..20_000).map { it to it }.toMap().valid()
+ }
+
+ "traverseValidated acummulates" {
+ forAll(Gen.map(Gen.int(), Gen.int())) { ints ->
+ val res: ValidatedNel> =
+ ints.traverseValidated(Semigroup.nonEmptyList()) { i -> if (i % 2 == 0) i.validNel() else i.invalidNel() }
+
+ val expected: ValidatedNel> =
+ NonEmptyList.fromList(ints.values.filterNot { it % 2 == 0 })
+ .fold({ ints.entries.filter { (_, v) -> v % 2 == 0 }.map { (k, v) -> k to v }.toMap().validNel() }, { it.invalid() })
+
+ res == expected
+ }
+ }
+
"can align maps" {
// aligned keySet is union of a's and b's keys
forAll(Gen.map(Gen.long(), Gen.bool()), Gen.map(Gen.long(), Gen.bool())) { a, b ->
diff --git a/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/NonEmptyListTest.kt b/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/NonEmptyListTest.kt
index 94da056547d..9e78a8eec5e 100644
--- a/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/NonEmptyListTest.kt
+++ b/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/NonEmptyListTest.kt
@@ -6,6 +6,7 @@ import arrow.core.test.laws.SemigroupLaws
import arrow.typeclasses.Semigroup
import io.kotlintest.properties.Gen
import io.kotlintest.properties.forAll
+import io.kotlintest.shouldBe
import kotlin.math.max
import kotlin.math.min
@@ -14,6 +15,57 @@ class NonEmptyListTest : UnitSpec() {
testLaws(SemigroupLaws.laws(Semigroup.nonEmptyList(), Gen.nonEmptyList(Gen.int())))
+ "traverseEither stack-safe" {
+ // also verifies result order and execution order (l to r)
+ val acc = mutableListOf()
+ val res = NonEmptyList.fromListUnsafe((0..20_000).toList()).traverseEither { a ->
+ acc.add(a)
+ Either.Right(a)
+ }
+ res shouldBe Either.Right(NonEmptyList.fromListUnsafe(acc))
+ res shouldBe Either.Right(NonEmptyList.fromListUnsafe((0..20_000).toList()))
+ }
+
+ "traverseEither short-circuit" {
+ forAll(Gen.nonEmptyList(Gen.int())) { ints ->
+ val acc = mutableListOf()
+ val evens = ints.traverseEither {
+ if (it % 2 == 0) {
+ acc.add(it)
+ Either.Right(it)
+ } else Either.Left(it)
+ }
+ acc == ints.takeWhile { it % 2 == 0 } &&
+ when (evens) {
+ is Either.Right -> evens.value == ints
+ is Either.Left -> evens.value == ints.first { it % 2 != 0 }
+ }
+ }
+ }
+
+ "traverseValidated stack-safe" {
+ // also verifies result order and execution order (l to r)
+ val acc = mutableListOf()
+ val res = (0..20_000).traverseValidated(Semigroup.string()) {
+ acc.add(it)
+ Validated.Valid(it)
+ }
+ res shouldBe Validated.Valid(acc)
+ res shouldBe Validated.Valid((0..20_000).toList())
+ }
+
+ "traverseValidated acummulates" {
+ forAll(Gen.nonEmptyList(Gen.int())) { ints ->
+ val res: ValidatedNel> =
+ ints.traverseValidated(Semigroup.nonEmptyList()) { i -> if (i % 2 == 0) i.validNel() else i.invalidNel() }
+
+ val expected: ValidatedNel> = NonEmptyList.fromList(ints.filterNot { it % 2 == 0 })
+ .fold({ NonEmptyList.fromListUnsafe(ints.filter { it % 2 == 0 }).validNel() }, { it.invalid() })
+
+ res == expected
+ }
+ }
+
"can align lists with different lengths" {
forAll(Gen.nonEmptyList(Gen.bool()), Gen.nonEmptyList(Gen.bool())) { a, b ->
a.align(b).size == max(a.size, b.size)
diff --git a/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/SequenceKTest.kt b/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/SequenceKTest.kt
index 04f33b86445..25045568c9a 100644
--- a/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/SequenceKTest.kt
+++ b/arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/SequenceKTest.kt
@@ -4,6 +4,7 @@ import arrow.core.test.UnitSpec
import arrow.core.test.generators.sequence
import arrow.core.test.laws.MonoidLaws
import arrow.typeclasses.Monoid
+import arrow.typeclasses.Semigroup
import io.kotlintest.matchers.sequences.shouldBeEmpty
import io.kotlintest.properties.Gen
import io.kotlintest.properties.forAll
@@ -17,10 +18,43 @@ class SequenceKTest : UnitSpec() {
testLaws(MonoidLaws.laws(Monoid.sequence(), Gen.sequence(Gen.int())) { s1, s2 -> s1.toList() == s2.toList() })
- "traverseEither is stacksafe over very long collections and short circuits properly" {
- // This has to traverse 30k elements till it reaches None and terminates
- generateSequence(0) { it + 1 }.map { if (it < 20_000) Either.Right(it) else Either.Left(Unit) }
- .sequenceEither() shouldBe Either.Left(Unit)
+ "traverseEither stack-safe" {
+ // also verifies result order and execution order (l to r)
+ val acc = mutableListOf()
+ val res = generateSequence(0) { it + 1 }.traverseEither { a ->
+ if (a > 20_000) {
+ Either.Left(Unit)
+ } else {
+ acc.add(a)
+ Either.Right(a)
+ }
+ }
+ acc shouldBe (0..20_000).toList()
+ res shouldBe Either.Left(Unit)
+ }
+
+ "traverseValidated stack-safe" {
+ // also verifies result order and execution order (l to r)
+ val acc = mutableListOf()
+ val res = (0..20_000).asSequence().traverseValidated(Semigroup.string()) {
+ acc.add(it)
+ Validated.Valid(it)
+ }.map { it.toList() }
+ res shouldBe Validated.Valid(acc)
+ res shouldBe Validated.Valid((0..20_000).toList())
+ }
+
+ "traverseValidated acummulates" {
+ forAll(Gen.list(Gen.int())) { ints ->
+ val ints = ints.asSequence()
+ val res: ValidatedNel> = ints.map { i -> if (i % 2 == 0) i.validNel() else i.invalidNel() }
+ .sequenceValidated(Semigroup.nonEmptyList())
+
+ val expected: ValidatedNel> = NonEmptyList.fromList(ints.filterNot { it % 2 == 0 }.toList())
+ .fold({ ints.filter { it % 2 == 0 }.validNel() }, { it.invalid() })
+
+ res.map { it.toList() } == expected.map { it.toList() }
+ }
}
"zip3" {