Skip to content

Commit

Permalink
Add SemigroupK and MonadCombine instances for StateT.
Browse files Browse the repository at this point in the history
  • Loading branch information
peterneyens committed Jun 17, 2016
1 parent 8f94d13 commit 96b6f4d
Show file tree
Hide file tree
Showing 2 changed files with 93 additions and 17 deletions.
48 changes: 38 additions & 10 deletions core/src/main/scala/cats/data/StateT.scala
Original file line number Diff line number Diff line change
Expand Up @@ -135,24 +135,32 @@ object StateT extends StateTInstances {
StateT(s => F.pure((s, a)))
}

private[data] sealed abstract class StateTInstances extends StateTInstances1 {
private[data] sealed trait StateTInstances extends StateTInstances1 {
implicit def catsDataMonadStateForStateT[F[_], S](implicit F0: Monad[F]): MonadState[StateT[F, S, ?], S] =
new StateTMonadState[F, S] { implicit def F = F0 }

implicit def catsDataLiftForStateT[S]: TransLift.Aux[StateT[?[_], S, ?], Applicative] =
new TransLift[StateT[?[_], S, ?]] {
type TC[M[_]] = Applicative[M]

def liftT[M[_]: Applicative, A](ma: M[A]): StateT[M, S, A] = StateT(s => Applicative[M].map(ma)(s -> _))
}

new StateTTransLift[S] {}
}

private[data] sealed abstract class StateTInstances1 {
private[data] sealed trait StateTInstances1 extends StateTInstances2 {
implicit def catsDataMonadRecForStateT[F[_], S](implicit F0: MonadRec[F]): MonadRec[StateT[F, S, ?]] =
new StateTMonadRec[F, S] { implicit def F = F0 }
}

private[data] sealed trait StateTInstances2 extends StateTInstances3 {
implicit def catsDataMonadCombineForStateT[F[_], S](implicit F0: MonadCombine[F]): MonadCombine[StateT[F, S, ?]] =
new StateTMonadCombine[F, S] { implicit def F = F0 }
}

private[data] sealed trait StateTInstances3 {
implicit def catsDataMonadForStateT[F[_], S](implicit F0: Monad[F]): Monad[StateT[F, S, ?]] =
new StateTMonad[F, S] { implicit def F = F0 }

implicit def catsDataSemigroupKForStateT[F[_], S](implicit F0: Monad[F], G0: SemigroupK[F]): SemigroupK[StateT[F, S, ?]] =
new StateTSemigroupK[F, S] { implicit def F = F0; implicit def G = G0 }
}

// To workaround SI-7139 `object State` needs to be defined inside the package object
// together with the type alias.
private[data] abstract class StateFunctions {
Expand Down Expand Up @@ -195,8 +203,7 @@ private[data] sealed trait StateTMonad[F[_], S] extends Monad[StateT[F, S, ?]] {
def flatMap[A, B](fa: StateT[F, S, A])(f: A => StateT[F, S, B]): StateT[F, S, B] =
fa.flatMap(f)

override def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] =
fa.map(f)
override def map[A, B](fa: StateT[F, S, A])(f: A => B): StateT[F, S, B] = fa.map(f)
}

private[data] sealed trait StateTMonadState[F[_], S] extends MonadState[StateT[F, S, ?], S] with StateTMonad[F, S] {
Expand All @@ -213,3 +220,24 @@ private[data] sealed trait StateTMonadRec[F[_], S] extends MonadRec[StateT[F, S,
case (s, a) => F.map(f(a).run(s)) { case (s, ab) => ab.bimap((s, _), (s, _)) }
})
}

private[data] sealed trait StateTTransLift[S] extends TransLift[StateT[?[_], S, ?]] {
type TC[M[_]] = Applicative[M]

def liftT[M[_]: Applicative, A](ma: M[A]): StateT[M, S, A] = StateT(s => Applicative[M].map(ma)(s -> _))
}

private[data] sealed trait StateTSemigroupK[F[_], S] extends SemigroupK[StateT[F, S, ?]] {
implicit def F: Monad[F]
implicit def G: SemigroupK[F]

def combineK[A](x: StateT[F, S, A], y: StateT[F, S, A]): StateT[F, S, A] =
StateT(s => G.combineK(x.run(s), y.run(s)))
}

private[data] sealed trait StateTMonadCombine[F[_], S] extends MonadCombine[StateT[F, S, ?]] with StateTMonad[F, S] with StateTSemigroupK[F, S] with StateTTransLift[S] {
implicit def F: MonadCombine[F]
override def G: MonadCombine[F] = F

def empty[A]: StateT[F, S, A] = liftT[F, A](F.empty[A])
}
62 changes: 55 additions & 7 deletions tests/src/test/scala/cats/tests/StateTTests.scala
Original file line number Diff line number Diff line change
@@ -1,9 +1,9 @@
package cats
package tests

import cats.kernel.std.tuple._
import cats.laws.discipline.{CartesianTests, MonadRecTests, MonadStateTests, SerializableTests}
import cats.data.{State, StateT}
import cats.kernel.std.tuple._
import cats.laws.discipline._
import cats.laws.discipline.eq._
import cats.laws.discipline.arbitrary._
import org.scalacheck.Arbitrary
Expand Down Expand Up @@ -114,14 +114,62 @@ class StateTTests extends CatsSuite {
got should === (expected)
}


implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[ListWrapper, Int, ?]](StateT.catsDataMonadForStateT(ListWrapper.monad))

{
// F has a Monad
implicit val F = ListWrapper.monad

checkAll("StateT[ListWrapper, Int, Int]", MonadStateTests[StateT[ListWrapper, Int, ?], Int].monadState[Int, Int, Int])
// checkAll("MonadState[StateT[ListWrapper, Int, ?], Int]", SerializableTests.serializable(MonadState[StateT[ListWrapper, Int, ?], Int]))
checkAll("MonadState[StateT[List, Int, ?], Int]", SerializableTests.serializable(MonadState[StateT[List, Int, ?], Int]))

Monad[StateT[ListWrapper, Int, ?]]
FlatMap[StateT[ListWrapper, Int, ?]]
Applicative[StateT[ListWrapper, Int, ?]]
Apply[StateT[ListWrapper, Int, ?]]
Functor[StateT[ListWrapper, Int, ?]]
}

{
// F has a MonadRec
implicit val F = ListWrapper.monadRec

checkAll("StateT[ListWrapper, Int, Int]", MonadRecTests[StateT[ListWrapper, Int, ?]].monadRec[Int, Int, Int])
checkAll("MonadRec[StateT[ListWrapper, Int, ?]]", SerializableTests.serializable(MonadRec[StateT[ListWrapper, Int, ?]]))

Monad[StateT[ListWrapper, Int, ?]]
FlatMap[StateT[ListWrapper, Int, ?]]
Applicative[StateT[ListWrapper, Int, ?]]
Apply[StateT[ListWrapper, Int, ?]]
Functor[StateT[ListWrapper, Int, ?]]
}

{
implicit val iso = CartesianTests.Isomorphisms.invariant[StateT[Option, Int, ?]]
// F has a Monad and a SemigroupK
implicit def F = ListWrapper.monad
implicit def S = ListWrapper.semigroupK

checkAll("StateT[Option, Int, Int]", MonadStateTests[StateT[Option, Int, ?], Int].monadState[Int, Int, Int])
checkAll("MonadState[StateT[Option, Int, ?], Int]", SerializableTests.serializable(MonadState[StateT[Option, Int, ?], Int]))
checkAll("StateT[ListWrapper, Int, Int]", SemigroupKTests[StateT[ListWrapper, Int, ?]].semigroupK[Int])
checkAll("SemigroupK[StateT[ListWrapper, Int, ?]]", SerializableTests.serializable(SemigroupK[StateT[ListWrapper, Int, ?]]))
}

checkAll("StateT[Option, Int, Int]", MonadRecTests[StateT[Option, Int, ?]].monadRec[Int, Int, Int])
checkAll("MonadRec[StateT[Option, Int, ?]]", SerializableTests.serializable(MonadRec[StateT[Option, Int, ?]]))
{
// F has a MonadCombine
implicit def F = ListWrapper.monadCombine

checkAll("StateT[ListWrapper, Int, Int]", MonadCombineTests[StateT[ListWrapper, Int, ?]].monadCombine[Int, Int, Int])
checkAll("MonadCombine[StateT[ListWrapper, Int, ?]]", SerializableTests.serializable(MonadCombine[StateT[ListWrapper, Int, ?]]))

Monad[StateT[ListWrapper, Int, ?]]
FlatMap[StateT[ListWrapper, Int, ?]]
Alternative[StateT[ListWrapper, Int, ?]]
Applicative[StateT[ListWrapper, Int, ?]]
Apply[StateT[ListWrapper, Int, ?]]
Functor[StateT[ListWrapper, Int, ?]]
MonoidK[StateT[ListWrapper, Int, ?]]
SemigroupK[StateT[ListWrapper, Int, ?]]
}

{
Expand Down

0 comments on commit 96b6f4d

Please sign in to comment.