Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Avoid storing more transformed elements than necessary #3376

Merged
merged 6 commits into from
Feb 18, 2024
Merged
Show file tree
Hide file tree
Changes from 1 commit
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 2 additions & 0 deletions arrow-libs/core/arrow-core/api/arrow-core.api
Original file line number Diff line number Diff line change
Expand Up @@ -3614,6 +3614,8 @@ public final class arrow/core/raise/RaiseKt {
public static final fun forEachAccumulating (Larrow/core/raise/Raise;Ljava/util/Iterator;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;)V
public static final fun forEachAccumulating (Larrow/core/raise/Raise;Lkotlin/sequences/Sequence;Lkotlin/jvm/functions/Function2;)V
public static final fun forEachAccumulating (Larrow/core/raise/Raise;Lkotlin/sequences/Sequence;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;)V
public static final fun forEachAccumulatingDouble (Larrow/core/raise/Raise;Ljava/util/Iterator;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;)V
public static final fun forEachAccumulatingDouble (Larrow/core/raise/Raise;Ljava/util/Iterator;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;Lkotlin/jvm/functions/Function2;)V
public static final fun get (Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
public static final fun get (Lkotlin/jvm/functions/Function2;Lkotlin/coroutines/Continuation;)Ljava/lang/Object;
public static final fun getOrElse (Lkotlin/jvm/functions/Function1;Lkotlin/jvm/functions/Function1;)Ljava/lang/Object;
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -14,7 +14,6 @@ import arrow.core.Validated
import arrow.core.ValidatedDeprMsg
import arrow.core.collectionSizeOrDefault
import arrow.core.ValidatedNel
import arrow.core.mapOrAccumulate
import arrow.core.nonEmptyListOf
import arrow.core.toNonEmptyListOrNull
import arrow.core.toNonEmptySetOrNull
Expand Down Expand Up @@ -521,10 +520,23 @@ public inline fun <Error, A> Raise<Error>.forEachAccumulating(
iterator: Iterator<A>,
combine: (Error, Error) -> Error,
@BuilderInference block: RaiseAccumulate<Error>.(A) -> Unit
): Unit = forEachAccumulating(iterator, combine, block, block)

@PublishedApi @JvmName("forEachAccumulatingDouble")
internal inline fun <Error, A> Raise<Error>.forEachAccumulating(
iterator: Iterator<A>,
combine: (Error, Error) -> Error,
@BuilderInference blockUntilError: RaiseAccumulate<Error>.(A) -> Unit,
@BuilderInference blockAfterError: RaiseAccumulate<Error>.(A) -> Unit
) {
var error: Any? = EmptyValue
var errorArose = false
serras marked this conversation as resolved.
Show resolved Hide resolved
for (item in iterator) {
recover({ block(RaiseAccumulate(this), item) }) { errors ->
recover({
if (errorArose) blockAfterError(RaiseAccumulate(this), item)
else blockUntilError(RaiseAccumulate(this), item)
}) { errors ->
errorArose = true
error = combine(error, errors.reduce(combine), combine)
}
}
Expand All @@ -547,10 +559,26 @@ public inline fun <Error, A> Raise<NonEmptyList<Error>>.forEachAccumulating(
public inline fun <Error, A> Raise<NonEmptyList<Error>>.forEachAccumulating(
iterator: Iterator<A>,
@BuilderInference block: RaiseAccumulate<Error>.(A) -> Unit
): Unit = forEachAccumulating(iterator, block, block)

/**
* Allows to change what to do once the first error is raised.
* Used to provide more performant [mapOrAccumulate].
*/
@PublishedApi @JvmName("forEachAccumulatingDouble")
serras marked this conversation as resolved.
Show resolved Hide resolved
internal inline fun <Error, A> Raise<NonEmptyList<Error>>.forEachAccumulating(
iterator: Iterator<A>,
@BuilderInference blockUntilError: RaiseAccumulate<Error>.(A) -> Unit,
@BuilderInference blockAfterError: RaiseAccumulate<Error>.(A) -> Unit
Copy link
Member

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: Do we really need two lambdas here? 🤔

Or was that for JvmName? Otherwise I am also fine using forEachAccumulatingImpl or something. I've started doing that in my own projects, similar how to KotlinX Coroutines does this with produceImpl. flattenMergeImpl, etc.

Copy link
Member Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

We do, because we still need to call transform on each value, just not add the result to the list.

Copy link
Collaborator

@kyay10 kyay10 Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Why not have:

  @BuilderInference block: RaiseAccumulate<Error>.(A) -> B,
  add: (A, B) -> Unit

and simply we only call add if we don't have any errors.
We could also just have a Boolean parameter instead, with similar effects.

I really also dislike the binary-size blowup that can occur because mapOrAccumulate inlines its block in 2 different places.

) {
val error: MutableList<Error> = mutableListOf()
var errorArose = false
Copy link
Collaborator

@kyay10 kyay10 Feb 16, 2024

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

Nit: same as above, but with error.isEmpty()

for (item in iterator) {
recover({ block(RaiseAccumulate(this), item) }) {
recover({
if (errorArose) blockAfterError(RaiseAccumulate(this), item)
else blockUntilError(RaiseAccumulate(this), item)
}) {
errorArose = true
error.addAll(it)
}
}
Expand All @@ -570,7 +598,12 @@ public inline fun <Error, A, B> Raise<Error>.mapOrAccumulate(
combine: (Error, Error) -> Error,
@BuilderInference transform: RaiseAccumulate<Error>.(A) -> B
): List<B> = buildList(iterable.collectionSizeOrDefault(10)) {
forEachAccumulating(iterable, combine) { add(transform(it)) }
forEachAccumulating(
iterable.iterator(),
combine,
blockUntilError = { add(transform(it)) },
blockAfterError = { transform(it) }
)
}

/**
Expand All @@ -585,7 +618,11 @@ public inline fun <Error, A, B> Raise<NonEmptyList<Error>>.mapOrAccumulate(
iterable: Iterable<A>,
@BuilderInference transform: RaiseAccumulate<Error>.(A) -> B
): List<B> = buildList(iterable.collectionSizeOrDefault(10)) {
forEachAccumulating(iterable) { add(transform(it)) }
forEachAccumulating(
iterable.iterator(),
blockUntilError = { add(transform(it)) },
blockAfterError = { transform(it) }
)
}

/**
Expand All @@ -601,7 +638,12 @@ public inline fun <Error, A, B> Raise<Error>.mapOrAccumulate(
combine: (Error, Error) -> Error,
@BuilderInference transform: RaiseAccumulate<Error>.(A) -> B
): List<B> = buildList {
forEachAccumulating(sequence, combine) { add(transform(it)) }
forEachAccumulating(
sequence.iterator(),
combine,
blockUntilError = { add(transform(it)) },
blockAfterError = { transform(it) }
)
}

/**
Expand All @@ -616,7 +658,11 @@ public inline fun <Error, A, B> Raise<NonEmptyList<Error>>.mapOrAccumulate(
sequence: Sequence<A>,
@BuilderInference transform: RaiseAccumulate<Error>.(A) -> B
): List<B> = buildList {
forEachAccumulating(sequence) { add(transform(it)) }
forEachAccumulating(
sequence.iterator(),
blockUntilError = { add(transform(it)) },
blockAfterError = { transform(it) }
)
}

/**
Expand Down Expand Up @@ -644,22 +690,37 @@ public inline fun <Error, A, B> Raise<NonEmptyList<Error>>.mapOrAccumulate(
nonEmptySet: NonEmptySet<A>,
@BuilderInference transform: RaiseAccumulate<Error>.(A) -> B
): NonEmptySet<B> = buildSet(nonEmptySet.size) {
forEachAccumulating(nonEmptySet) { add(transform(it)) }
forEachAccumulating(
nonEmptySet.iterator(),
blockUntilError = { add(transform(it)) },
blockAfterError = { transform(it) }
)
}.toNonEmptySetOrNull()!!

@RaiseDSL
public inline fun <K, Error, A, B> Raise<Error>.mapOrAccumulate(
map: Map<K, A>,
combine: (Error, Error) -> Error,
@BuilderInference transform: RaiseAccumulate<Error>.(Map.Entry<K, A>) -> B
): Map<K, B> = buildMap(map.size) {
forEachAccumulating(map.entries, combine) { put(it.key, transform(it)) }
forEachAccumulating(
map.entries.iterator(),
combine,
blockUntilError = { put(it.key, transform(it)) },
blockAfterError = { transform(it) }
)
}

@RaiseDSL
public inline fun <K, Error, A, B> Raise<NonEmptyList<Error>>.mapOrAccumulate(
map: Map<K, A>,
@BuilderInference transform: RaiseAccumulate<Error>.(Map.Entry<K, A>) -> B
): Map<K, B> = buildMap(map.size) {
forEachAccumulating(map.entries) { put(it.key, transform(it)) }
forEachAccumulating(
map.entries.iterator(),
blockUntilError = { put(it.key, transform(it)) },
blockAfterError = { transform(it) }
)
}

/**
Expand Down Expand Up @@ -709,7 +770,7 @@ public open class RaiseAccumulate<Error>(

@RaiseDSL
@JvmName("_mapOrAccumulate")
public inline fun <A, B> mapOrAccumulate(
public fun <A, B> mapOrAccumulate(
kyay10 marked this conversation as resolved.
Show resolved Hide resolved
set: NonEmptySet<A>,
transform: RaiseAccumulate<Error>.(A) -> B
): NonEmptySet<B> = raise.mapOrAccumulate(set, transform)
Expand Down
Loading