Skip to content

Commit

Permalink
Traverse fixes (#2365)
Browse files Browse the repository at this point in the history
* FoldRight order

* Test and fix traverse implementations

* Undo foldRight changes

Co-authored-by: Simon Vergauwen <[email protected]>
  • Loading branch information
1Jajen1 and nomisRev authored Apr 6, 2021
1 parent 4393438 commit 6cecbc8
Show file tree
Hide file tree
Showing 8 changed files with 238 additions and 56 deletions.
21 changes: 12 additions & 9 deletions arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Iterable.kt
Original file line number Diff line number Diff line change
Expand Up @@ -287,13 +287,16 @@ inline fun <A, B> Iterable<A>.foldRight(initial: B, operation: (A, acc: B) -> B)
else -> reversed().fold(initial) { acc, a -> operation(a, acc) }
}

inline fun <E, A, B> Iterable<A>.traverseEither(f: (A) -> Either<E, B>): Either<E, List<B>> =
foldRight<A, Either<E, List<B>>>(emptyList<B>().right()) { a, acc ->
inline fun <E, A, B> Iterable<A>.traverseEither(f: (A) -> Either<E, B>): Either<E, List<B>> {
val acc = mutableListOf<B>()
forEach { a ->
when (val res = f(a)) {
is Right -> acc.map { bs -> listOf(res.value) + bs }
is Left -> res
is Right -> acc.add(res.value)
is Left -> return@traverseEither res
}
}
return acc.right()
}

fun <E, A> Iterable<Either<E, A>>.sequenceEither(): Either<E, List<A>> =
traverseEither(::identity)
Expand All @@ -302,15 +305,15 @@ inline fun <E, A, B> Iterable<A>.traverseValidated(
semigroup: Semigroup<E>,
f: (A) -> Validated<E, B>
): Validated<E, List<B>> = semigroup.run {
foldRight<A, Validated<E, List<B>>>(emptyList<B>().valid()) { a, acc ->
fold(Valid(mutableListOf<B>()) as Validated<E, MutableList<B>>) { acc, a ->
when (val res = f(a)) {
is Validated.Valid -> when (acc) {
is Validated.Valid -> acc.map { bs -> listOf(res.value) + bs }
is Validated.Invalid -> acc
is Valid -> acc.also { it.value.add(res.value) }
is Invalid -> acc
}
is Validated.Invalid -> when (acc) {
is Validated.Valid -> res
is Validated.Invalid -> Invalid(res.value.combine(acc.value))
is Valid -> res
is Invalid -> acc.value.combine(res.value).invalid()
}
}
}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -412,13 +412,17 @@ fun <A, B, C> NonEmptyList<C>.unzip(f: (C) -> Pair<A, B>): Pair<NonEmptyList<A>,
}
}

inline fun <E, A, B> NonEmptyList<A>.traverseEither(f: (A) -> Either<E, B>): Either<E, NonEmptyList<B>> =
foldRight(f(head).map(::nonEmptyListOf)) { a, acc ->
inline fun <E, A, B> NonEmptyList<A>.traverseEither(f: (A) -> Either<E, B>): Either<E, NonEmptyList<B>> {
val acc = mutableListOf<B>()
forEach { a ->
when (val res = f(a)) {
is Right -> acc.map { bs -> nonEmptyListOf(res.value) + bs }
is Left -> res
is Right -> acc.add(res.value)
is Left -> return@traverseEither res
}
}
// Safe due to traverse laws
return NonEmptyList.fromListUnsafe(acc).right()
}

fun <E, A> NonEmptyList<Either<E, A>>.sequenceEither(): Either<E, NonEmptyList<A>> =
traverseEither(::identity)
Expand All @@ -427,18 +431,18 @@ inline fun <E, A, B> NonEmptyList<A>.traverseValidated(
semigroup: Semigroup<E>,
f: (A) -> Validated<E, B>
): Validated<E, NonEmptyList<B>> =
foldRight(f(head).map(::nonEmptyListOf)) { a, acc ->
fold(mutableListOf<B>().valid() as Validated<E, MutableList<B>>) { acc, a ->
when (val res = f(a)) {
is Validated.Valid -> when (acc) {
is Validated.Valid -> acc.map { bs -> nonEmptyListOf(res.value) + bs }
is Validated.Invalid -> acc
is Valid -> when (acc) {
is Valid -> acc.also { it.value.add(res.value) }
is Invalid -> acc
}
is Validated.Invalid -> when (acc) {
is Validated.Valid -> res
is Validated.Invalid -> Invalid(semigroup.run { res.value.combine(acc.value) })
is Invalid -> when (acc) {
is Valid -> res
is Invalid -> semigroup.run { acc.value.combine(res.value).invalid() }
}
}
}
}.map { NonEmptyList.fromListUnsafe(it) }

fun <E, A> NonEmptyList<Validated<E, A>>.sequenceValidated(semigroup: Semigroup<E>): Validated<E, NonEmptyList<A>> =
traverseValidated(semigroup, ::identity)
34 changes: 23 additions & 11 deletions arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/Sequence.kt
Original file line number Diff line number Diff line change
Expand Up @@ -643,23 +643,35 @@ fun <A> Sequence<A>.split(): Pair<Sequence<A>, A>? =
fun <A> Sequence<A>.tail(): Sequence<A> =
drop(1)

fun <E, A, B> Sequence<A>.traverseEither(f: (A) -> Either<E, B>): Either<E, Sequence<B>> =
foldRight<A, Either<E, Sequence<B>>>(Eval.now(sequenceOf<B>().right())) { a, eval ->
fun <E, A, B> Sequence<A>.traverseEither(f: (A) -> Either<E, B>): Either<E, Sequence<B>> {
// Note: Using a mutable list here avoids the stackoverflows one can accidentally create when using
// Sequence.plus instead. But we don't convert the sequence to a list beforehand to avoid
// forcing too much of the sequence to be evaluated.
val acc = mutableListOf<B>()
forEach { a ->
when (val res = f(a)) {
is Right -> eval.map { either ->
either.map { bs -> sequenceOf(res.value) + bs }
}
is Left -> Eval.now(res.value.left())
is Right -> acc.add(res.value)
is Left -> return@traverseEither res
}
}.value()
}
return acc.asSequence().right()
}

fun <E, A, B> Sequence<A>.traverseValidated(
semigroup: Semigroup<E>,
f: (A) -> Validated<E, B>
): Validated<E, Sequence<B>> =
foldRight<A, Validated<E, Sequence<B>>>(Eval.now(emptySequence<B>().valid())) { a, acc ->
acc.map { f(a).zip(semigroup, it) { b, bs -> sequenceOf(b) + bs } }
}.value()
): Validated<E, Sequence<B>> = fold(mutableListOf<B>().valid() as Validated<E, MutableList<B>>) { acc, a ->
when (val res = f(a)) {
is Valid -> when (acc) {
is Valid -> acc.also { it.value.add(res.value) }
is Invalid -> acc
}
is Invalid -> when (acc) {
is Valid -> res
is Invalid -> semigroup.run { acc.value.combine(res.value).invalid() }
}
}
}.map { it.asSequence() }

/**
* splits an union into its component parts.
Expand Down
32 changes: 18 additions & 14 deletions arrow-libs/core/arrow-core/src/main/kotlin/arrow/core/map.kt
Original file line number Diff line number Diff line change
Expand Up @@ -189,33 +189,37 @@ fun <K, A, B> Map<K, A>.flatMap(f: (Map.Entry<K, A>) -> Map<K, B>): Map<K, B> =
f(entry)[entry.key]?.let { Pair(entry.key, it) }.asIterable()
}.toMap()

inline fun <K, E, A, B> Map<K, A>.traverseEither(f: (A) -> Either<E, B>): Either<E, Map<K, B>> =
foldRight(emptyMap<K, B>().right()) { (k, a), acc: Either<E, Map<K, B>> ->
when (val res = f(a)) {
is Right -> acc.map { bs: Map<K, B> -> mapOf(k to res.value) + bs }
is Left -> res
inline fun <K, E, A, B> Map<K, A>.traverseEither(f: (A) -> Either<E, B>): Either<E, Map<K, B>> {
val acc = mutableMapOf<K, B>()
forEach { (k, v) ->
when (val res = f(v)) {
is Right -> acc[k] = res.value
is Left -> return@traverseEither res
}
}
return acc.right()
}

fun <K, E, A> Map<K, Either<E, A>>.sequenceEither(): Either<E, Map<K, A>> =
traverseEither(::identity)

inline fun <K, E, A, B> Map<K, A>.traverseValidated(
semigroup: Semigroup<E>,
f: (A) -> Validated<E, B>
): Validated<E, Map<K, B>> =
foldRight<K, A, Validated<E, Map<K, B>>>(emptyMap<K, B>().valid()) { (k, a), acc ->
when (val res = f(a)) {
is Validated.Valid -> when (acc) {
is Validated.Valid -> acc.map { bs -> mapOf(k to res.value) + bs }
is Validated.Invalid -> acc
): Validated<E, Map<K, B>> {
return foldLeft(mutableMapOf<K, B>().valid() as Validated<E, MutableMap<K, B>>) { acc, (k, v) ->
when (val res = f(v)) {
is Valid -> when (acc) {
is Valid -> acc.also { it.value[k] = res.value }
is Invalid -> acc
}
is Validated.Invalid -> when (acc) {
is Validated.Valid -> res
is Validated.Invalid -> Invalid(semigroup.run { res.value.combine(acc.value) })
is Invalid -> when (acc) {
is Valid -> res
is Invalid -> semigroup.run { acc.value.combine(res.value).invalid() }
}
}
}
}

fun <K, E, A> Map<K, Validated<E, A>>.sequenceValidated(semigroup: Semigroup<E>): Validated<E, Map<K, A>> =
traverseValidated(semigroup, ::identity)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -12,20 +12,42 @@ import kotlin.math.min
class IterableTest : UnitSpec() {
init {
"traverseEither stack-safe" {
(0..20_000).map { Either.Right(it) }
.sequenceEither() shouldBe Either.Right((0..20_000).toList())
// also verifies result order and execution order (l to r)
val acc = mutableListOf<Int>()
val res = (0..20_000).traverseEither { a ->
acc.add(a)
Either.Right(a)
}
res shouldBe Either.Right(acc)
res shouldBe Either.Right((0..20_000).toList())
}

"traverseEither short-circuit" {
forAll(Gen.list(Gen.int())) { ints ->
(ints.map { Either.Right(it) } + Either.Left(Unit))
.sequenceEither() == Either.Left(Unit)
val acc = mutableListOf<Int>()
val evens = ints.traverseEither {
if (it % 2 == 0) {
acc.add(it)
Either.Right(it)
} else Either.Left(it)
}
acc == ints.takeWhile { it % 2 == 0 } &&
when (evens) {
is Either.Right -> evens.value == ints
is Either.Left -> evens.value == ints.first { it % 2 != 0 }
}
}
}

"traverseValidated stack-safe" {
(0..20_000).map { Validated.Valid(it) }
.sequenceValidated(Semigroup.string()) shouldBe Validated.Valid((0..20_000).toList())
// also verifies result order and execution order (l to r)
val acc = mutableListOf<Int>()
val res = (0..20_000).traverseValidated(Semigroup.string()) {
acc.add(it)
Validated.Valid(it)
}
res shouldBe Validated.Valid(acc)
res shouldBe Validated.Valid((0..20_000).toList())
}

"traverseValidated acummulates" {
Expand Down
51 changes: 51 additions & 0 deletions arrow-libs/core/arrow-core/src/test/kotlin/arrow/core/MapKTest.kt
Original file line number Diff line number Diff line change
Expand Up @@ -8,12 +8,63 @@ import arrow.typeclasses.Monoid
import arrow.typeclasses.Semigroup
import io.kotlintest.properties.Gen
import io.kotlintest.properties.forAll
import io.kotlintest.shouldBe

class MapKTest : UnitSpec() {

init {
testLaws(MonoidLaws.laws(Monoid.map(Semigroup.int()), Gen.map(Gen.longSmall(), Gen.intSmall())))

"traverseEither is stacksafe" {
val acc = mutableListOf<Int>()
val res = (0..20_000).map { it to it }.toMap().traverseEither { v ->
acc.add(v)
Either.Right(v)
}
res shouldBe acc.map { it to it }.toMap().right()
res shouldBe (0..20_000).map { it to it }.toMap().right()
}

"traverseEither short-circuit" {
forAll(Gen.map(Gen.int(), Gen.int())) { ints ->
val acc = mutableListOf<Int>()
val evens = ints.traverseEither {
if (it % 2 == 0) {
acc.add(it)
Either.Right(it)
} else Either.Left(it)
}
acc == ints.values.takeWhile { it % 2 == 0 } &&
when (evens) {
is Either.Right -> evens.value == ints
is Either.Left -> evens.value == ints.values.first { it % 2 != 0 }
}
}
}

"traverseValidated is stacksafe" {
val acc = mutableListOf<Int>()
val res = (0..20_000).map { it to it }.toMap().traverseValidated(Semigroup.string()) { v ->
acc.add(v)
Validated.Valid(v)
}
res shouldBe acc.map { it to it }.toMap().valid()
res shouldBe (0..20_000).map { it to it }.toMap().valid()
}

"traverseValidated acummulates" {
forAll(Gen.map(Gen.int(), Gen.int())) { ints ->
val res: ValidatedNel<Int, Map<Int, Int>> =
ints.traverseValidated(Semigroup.nonEmptyList()) { i -> if (i % 2 == 0) i.validNel() else i.invalidNel() }

val expected: ValidatedNel<Int, Map<Int, Int>> =
NonEmptyList.fromList(ints.values.filterNot { it % 2 == 0 })
.fold({ ints.entries.filter { (_, v) -> v % 2 == 0 }.map { (k, v) -> k to v }.toMap().validNel() }, { it.invalid() })

res == expected
}
}

"can align maps" {
// aligned keySet is union of a's and b's keys
forAll(Gen.map(Gen.long(), Gen.bool()), Gen.map(Gen.long(), Gen.bool())) { a, b ->
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import arrow.core.test.laws.SemigroupLaws
import arrow.typeclasses.Semigroup
import io.kotlintest.properties.Gen
import io.kotlintest.properties.forAll
import io.kotlintest.shouldBe
import kotlin.math.max
import kotlin.math.min

Expand All @@ -14,6 +15,57 @@ class NonEmptyListTest : UnitSpec() {

testLaws(SemigroupLaws.laws(Semigroup.nonEmptyList(), Gen.nonEmptyList(Gen.int())))

"traverseEither stack-safe" {
// also verifies result order and execution order (l to r)
val acc = mutableListOf<Int>()
val res = NonEmptyList.fromListUnsafe((0..20_000).toList()).traverseEither { a ->
acc.add(a)
Either.Right(a)
}
res shouldBe Either.Right(NonEmptyList.fromListUnsafe(acc))
res shouldBe Either.Right(NonEmptyList.fromListUnsafe((0..20_000).toList()))
}

"traverseEither short-circuit" {
forAll(Gen.nonEmptyList(Gen.int())) { ints ->
val acc = mutableListOf<Int>()
val evens = ints.traverseEither {
if (it % 2 == 0) {
acc.add(it)
Either.Right(it)
} else Either.Left(it)
}
acc == ints.takeWhile { it % 2 == 0 } &&
when (evens) {
is Either.Right -> evens.value == ints
is Either.Left -> evens.value == ints.first { it % 2 != 0 }
}
}
}

"traverseValidated stack-safe" {
// also verifies result order and execution order (l to r)
val acc = mutableListOf<Int>()
val res = (0..20_000).traverseValidated(Semigroup.string()) {
acc.add(it)
Validated.Valid(it)
}
res shouldBe Validated.Valid(acc)
res shouldBe Validated.Valid((0..20_000).toList())
}

"traverseValidated acummulates" {
forAll(Gen.nonEmptyList(Gen.int())) { ints ->
val res: ValidatedNel<Int, NonEmptyList<Int>> =
ints.traverseValidated(Semigroup.nonEmptyList()) { i -> if (i % 2 == 0) i.validNel() else i.invalidNel() }

val expected: ValidatedNel<Int, NonEmptyList<Int>> = NonEmptyList.fromList(ints.filterNot { it % 2 == 0 })
.fold({ NonEmptyList.fromListUnsafe(ints.filter { it % 2 == 0 }).validNel() }, { it.invalid() })

res == expected
}
}

"can align lists with different lengths" {
forAll(Gen.nonEmptyList(Gen.bool()), Gen.nonEmptyList(Gen.bool())) { a, b ->
a.align(b).size == max(a.size, b.size)
Expand Down
Loading

0 comments on commit 6cecbc8

Please sign in to comment.