diff --git a/core/src/main/scala/cats/Align.scala b/core/src/main/scala/cats/Align.scala index a6b890c681..884bda6ece 100644 --- a/core/src/main/scala/cats/Align.scala +++ b/core/src/main/scala/cats/Align.scala @@ -45,3 +45,9 @@ import cats.data.Ior def padZipWith[A, B, C](fa: F[A], fb: F[B])(f: (Option[A], Option[B]) => C): F[C] = alignWith(fa, fb)(ior => Function.tupled(f)(ior.pad)) } + +object Align { + def semigroup[F[_]: Align, A: Semigroup]: Semigroup[F[A]] = new Semigroup[F[A]] { + def combine(x: F[A], y: F[A]): F[A] = Align[F].alignCombine(x, y) + } +} diff --git a/core/src/main/scala/cats/Apply.scala b/core/src/main/scala/cats/Apply.scala index f618cd8df9..3a43f93a43 100644 --- a/core/src/main/scala/cats/Apply.scala +++ b/core/src/main/scala/cats/Apply.scala @@ -2,6 +2,7 @@ package cats import simulacrum.typeclass import simulacrum.noop +import cats.data.Ior /** * Weaker version of Applicative[F]; has apply but not pure. @@ -225,6 +226,11 @@ object Apply { */ def semigroup[F[_], A](implicit f: Apply[F], sg: Semigroup[A]): Semigroup[F[A]] = new ApplySemigroup[F, A](f, sg) + + def align[F[_]: Apply]: Align[F] = new Align[F] { + def align[A, B](fa: F[A], fb: F[B]): F[Ior[A, B]] = Apply[F].map2(fa, fb)(Ior.both) + def functor: Functor[F] = Apply[F] + } } private[cats] class ApplySemigroup[F[_], A](f: Apply[F], sg: Semigroup[A]) extends Semigroup[F[A]] { diff --git a/core/src/main/scala/cats/SemigroupK.scala b/core/src/main/scala/cats/SemigroupK.scala index f3ca1b5965..4aba7ba023 100644 --- a/core/src/main/scala/cats/SemigroupK.scala +++ b/core/src/main/scala/cats/SemigroupK.scala @@ -1,6 +1,7 @@ package cats import simulacrum.typeclass +import cats.data.Ior /** * SemigroupK is a universal semigroup which operates on kinds. @@ -68,3 +69,11 @@ import simulacrum.typeclass val F = self } } + +object SemigroupK { + def align[F[_]: SemigroupK: Functor]: Align[F] = new Align[F] { + def align[A, B](fa: F[A], fb: F[B]): F[Ior[A, B]] = + SemigroupK[F].combineK(Functor[F].map(fa)(Ior.left), Functor[F].map(fb)(Ior.right)) + def functor: Functor[F] = Functor[F] + } +} diff --git a/core/src/main/scala/cats/data/Const.scala b/core/src/main/scala/cats/data/Const.scala index 90a637b116..c8e1e6aba4 100644 --- a/core/src/main/scala/cats/data/Const.scala +++ b/core/src/main/scala/cats/data/Const.scala @@ -74,6 +74,12 @@ sealed abstract private[data] class ConstInstances extends ConstInstances0 { x.compare(y) } + implicit def catsDataAlignForConst[A: Semigroup]: Align[Const[A, *]] = new Align[Const[A, *]] { + def align[B, C](fa: Const[A, B], fb: Const[A, C]): Const[A, Ior[B, C]] = + Const(Semigroup[A].combine(fa.getConst, fb.getConst)) + def functor: Functor[Const[A, *]] = catsDataFunctorForConst + } + implicit def catsDataShowForConst[A: Show, B]: Show[Const[A, B]] = new Show[Const[A, B]] { def show(f: Const[A, B]): String = f.show } diff --git a/core/src/main/scala/cats/data/Validated.scala b/core/src/main/scala/cats/data/Validated.scala index 7805497b7a..ee0ffa3214 100644 --- a/core/src/main/scala/cats/data/Validated.scala +++ b/core/src/main/scala/cats/data/Validated.scala @@ -371,6 +371,27 @@ sealed abstract private[data] class ValidatedInstances extends ValidatedInstance } } + implicit def catsDataAlignForValidated[E: Semigroup]: Align[Validated[E, *]] = + new Align[Validated[E, *]] { + def functor: Functor[Validated[E, *]] = catsDataTraverseFunctorForValidated + def align[A, B](fa: Validated[E, A], fb: Validated[E, B]): Validated[E, Ior[A, B]] = + alignWith(fa, fb)(identity) + + override def alignWith[A, B, C](fa: Validated[E, A], fb: Validated[E, B])(f: Ior[A, B] => C): Validated[E, C] = + fa match { + case Invalid(e) => + fb match { + case Invalid(e2) => Invalid(Semigroup[E].combine(e, e2)) + case Valid(b) => Valid(f(Ior.right(b))) + } + case Valid(a) => + fb match { + case Invalid(e) => Valid(f(Ior.left(a))) + case Valid(b) => Valid(f(Ior.both(a, b))) + } + } + } + implicit def catsDataMonoidForValidated[A, B](implicit A: Semigroup[A], B: Monoid[B]): Monoid[Validated[A, B]] = new Monoid[Validated[A, B]] { def empty: Validated[A, B] = Valid(B.empty) diff --git a/laws/src/main/scala/cats/laws/AlignLaws.scala b/laws/src/main/scala/cats/laws/AlignLaws.scala index 1e81955206..590a11fb79 100644 --- a/laws/src/main/scala/cats/laws/AlignLaws.scala +++ b/laws/src/main/scala/cats/laws/AlignLaws.scala @@ -18,10 +18,7 @@ trait AlignLaws[F[_]] { def alignAssociativity[A, B, C](fa: F[A], fb: F[B], fc: F[C]): IsEq[F[Ior[Ior[A, B], C]]] = fa.align(fb).align(fc) <-> fa.align(fb.align(fc)).map(assoc) - def alignSelfBoth[A](fa: F[A]): IsEq[F[A Ior A]] = - fa.align(fa) <-> fa.map(a => Ior.both(a, a)) - - def alignHomomorphism[A, B, C, D](fa: F[A], fb: F[B], f: A => C, g: B => D): IsEq[F[C Ior D]] = + def alignHomomorphism[A, B, C, D](fa: F[A], fb: F[B], f: A => C, g: B => D): IsEq[F[Ior[C, D]]] = fa.map(f).align(fb.map(g)) <-> fa.align(fb).map(_.bimap(f, g)) def alignWithConsistent[A, B, C](fa: F[A], fb: F[B], f: A Ior B => C): IsEq[F[C]] = diff --git a/laws/src/main/scala/cats/laws/discipline/AlignTests.scala b/laws/src/main/scala/cats/laws/discipline/AlignTests.scala index 80f0dd13ed..8d35bb3acf 100644 --- a/laws/src/main/scala/cats/laws/discipline/AlignTests.scala +++ b/laws/src/main/scala/cats/laws/discipline/AlignTests.scala @@ -32,7 +32,6 @@ trait AlignTests[F[_]] extends Laws { new DefaultRuleSet(name = "align", parent = None, "align associativity" -> forAll(laws.alignAssociativity[A, B, C] _), - "align self both" -> forAll(laws.alignSelfBoth[A] _), "align homomorphism" -> forAll { (fa: F[A], fb: F[B], f: A => C, g: B => D) => laws.alignHomomorphism[A, B, C, D](fa, fb, f, g) }, diff --git a/tests/src/test/scala/cats/tests/AlignSuite.scala b/tests/src/test/scala/cats/tests/AlignSuite.scala new file mode 100644 index 0000000000..167bc0e760 --- /dev/null +++ b/tests/src/test/scala/cats/tests/AlignSuite.scala @@ -0,0 +1,14 @@ +package cats.tests + +import cats.Align +import cats.kernel.laws.discipline.SemigroupTests + +class AlignSuite extends CatsSuite { + { + val optionSemigroup = Align.semigroup[Option, Int] + checkAll("Align[Option].semigroup", SemigroupTests[Option[Int]](optionSemigroup).semigroup) + + val listSemigroup = Align.semigroup[List, String] + checkAll("Align[List].semigroup", SemigroupTests[List[String]](listSemigroup).semigroup) + } +} diff --git a/tests/src/test/scala/cats/tests/ApplicativeSuite.scala b/tests/src/test/scala/cats/tests/ApplicativeSuite.scala index e500697cfb..d85e050038 100644 --- a/tests/src/test/scala/cats/tests/ApplicativeSuite.scala +++ b/tests/src/test/scala/cats/tests/ApplicativeSuite.scala @@ -6,6 +6,7 @@ import cats.kernel.laws.discipline.{MonoidTests, SemigroupTests} import cats.data.{Const, Validated} import cats.laws.discipline.arbitrary._ import cats.laws.discipline.CoflatMapTests +import cats.laws.discipline.AlignTests class ApplicativeSuite extends CatsSuite { @@ -59,6 +60,15 @@ class ApplicativeSuite extends CatsSuite { implicit val constCoflatMap = Applicative.coflatMap[Const[String, *]] checkAll("Applicative[Const].coflatMap", CoflatMapTests[Const[String, *]].coflatMap[String, String, String]) + + implicit val listwrapperAlign = Apply.align[ListWrapper] + checkAll("Apply[ListWrapper].align", AlignTests[ListWrapper].align[Int, Int, Int, Int]) + + implicit val validatedAlign = Apply.align[Validated[String, *]] + checkAll("Apply[Validated].align", AlignTests[Validated[String, *]].align[Int, Int, Int, Int]) + + implicit val constAlign = Apply.align[Const[String, *]] + checkAll("Apply[Const].align", AlignTests[Const[String, *]].align[Int, Int, Int, Int]) } } diff --git a/tests/src/test/scala/cats/tests/ConstSuite.scala b/tests/src/test/scala/cats/tests/ConstSuite.scala index d7e5477ce9..4427328259 100644 --- a/tests/src/test/scala/cats/tests/ConstSuite.scala +++ b/tests/src/test/scala/cats/tests/ConstSuite.scala @@ -32,6 +32,9 @@ class ConstSuite extends CatsSuite { checkAll("Const[String, Int]", TraverseFilterTests[Const[String, *]].traverseFilter[Int, Int, Int]) checkAll("TraverseFilter[Const[String, *]]", SerializableTests.serializable(TraverseFilter[Const[String, *]])) + checkAll("Const[String, Int]", AlignTests[Const[String, *]].align[Int, Int, Int, Int]) + checkAll("Align[Const[String, *]]", SerializableTests.serializable(Align[Const[String, *]])) + // Get Apply[Const[C : Semigroup, *]], not Applicative[Const[C : Monoid, *]] { implicit def nonEmptyListSemigroup[A]: Semigroup[NonEmptyList[A]] = SemigroupK[NonEmptyList].algebra diff --git a/tests/src/test/scala/cats/tests/SemigroupKSuite.scala b/tests/src/test/scala/cats/tests/SemigroupKSuite.scala new file mode 100644 index 0000000000..fa30726f57 --- /dev/null +++ b/tests/src/test/scala/cats/tests/SemigroupKSuite.scala @@ -0,0 +1,20 @@ +package cats.tests + +import cats.SemigroupK +import cats.data.{Chain, Validated} +import cats.laws.discipline.AlignTests +import cats.laws.discipline.arbitrary._ + +class SemigroupKSuite extends CatsSuite { + { + implicit val listwrapperSemigroupK = ListWrapper.alternative + implicit val listwrapperAlign = SemigroupK.align[ListWrapper] + checkAll("SemigroupK[ListWrapper].align", AlignTests[ListWrapper].align[Int, Int, Int, Int]) + + implicit val validatedAlign = SemigroupK.align[Validated[String, *]] + checkAll("SemigroupK[Validated].align", AlignTests[Validated[String, *]].align[Int, Int, Int, Int]) + + implicit val chainAlign = SemigroupK.align[Chain] + checkAll("SemigroupK[Chain].align", AlignTests[Chain].align[Int, Int, Int, Int]) + } +} diff --git a/tests/src/test/scala/cats/tests/ValidatedSuite.scala b/tests/src/test/scala/cats/tests/ValidatedSuite.scala index a7bc31cb29..1a27de9dda 100644 --- a/tests/src/test/scala/cats/tests/ValidatedSuite.scala +++ b/tests/src/test/scala/cats/tests/ValidatedSuite.scala @@ -41,6 +41,9 @@ class ValidatedSuite extends CatsSuite { checkAll("CommutativeApplicative[Validated[Int, *]]", SerializableTests.serializable(CommutativeApplicative[Validated[Int, *]])) + checkAll("Validated[Int, Int]", AlignTests[Validated[Int, *]].align[Int, Int, Int, Int]) + checkAll("Align[Validated[Int, *]]", SerializableTests.serializable(Align[Validated[Int, *]])) + { implicit val L = ListWrapper.semigroup[String] checkAll("Validated[ListWrapper[String], *]", SemigroupKTests[Validated[ListWrapper[String], *]].semigroupK[Int])