diff --git a/core/src/main/scala/cats/data/StateT.scala b/core/src/main/scala/cats/data/StateT.scala index 439b86a6ff..f3a2f2b3f4 100644 --- a/core/src/main/scala/cats/data/StateT.scala +++ b/core/src/main/scala/cats/data/StateT.scala @@ -17,6 +17,15 @@ final class StateT[F[_], S, A](val runF: F[S => F[(S, A)]]) extends Serializable } }) + def flatMapF[B](faf: A => F[B])(implicit F: Monad[F]): StateT[F, S, B] = + StateT(s => + F.flatMap(runF) { fsf => + F.flatMap(fsf(s)) { case (s, a) => + F.map(faf(a))((s, _)) + } + } + ) + def map[B](f: A => B)(implicit F: Monad[F]): StateT[F, S, B] = transform { case (s, a) => (s, f(a)) } diff --git a/tests/src/test/scala/cats/tests/StateTTests.scala b/tests/src/test/scala/cats/tests/StateTTests.scala index 00d01c9550..9e0b8a447e 100644 --- a/tests/src/test/scala/cats/tests/StateTTests.scala +++ b/tests/src/test/scala/cats/tests/StateTTests.scala @@ -40,6 +40,12 @@ class StateTTests extends CatsSuite { } } + test("flatMap and flatMapF consistent") { + forAll { (stateT: StateT[Option, Long, Int], f: Int => Option[Int]) => + stateT.flatMap(a => StateT(s => f(a).map(b => (s, b)))) should === (stateT.flatMapF(f)) + } + } + test("runEmpty, runEmptyS, and runEmptyA consistent"){ forAll { (f: StateT[List, Long, Int]) => (f.runEmptyS zip f.runEmptyA) should === (f.runEmpty)