Skip to content

Commit

Permalink
unroll macros, similar to the peel. (#6840)
Browse files Browse the repository at this point in the history
  • Loading branch information
Unisay authored Feb 13, 2025
1 parent 021b068 commit c9e0eca
Show file tree
Hide file tree
Showing 5 changed files with 153 additions and 39 deletions.
Original file line number Diff line number Diff line change
@@ -0,0 +1,2 @@
({cpu: 4580180
| mem: 22420})
Original file line number Diff line number Diff line change
@@ -0,0 +1 @@
(con integer 10)
19 changes: 19 additions & 0 deletions plutus-tx-plugin/test/Recursion/9.6/length-unrolled.uplc.golden
Original file line number Diff line number Diff line change
@@ -0,0 +1,19 @@
(program
1.1.0
((\s -> s s)
(\s ds ->
case
ds
[ 0
, (\ds xs ->
addInteger
1
(case
xs
[ 0
, (\ds xs ->
addInteger
1
(case
xs
[0, (\ds xs -> addInteger 1 (s s xs))])) ])) ])))
87 changes: 64 additions & 23 deletions plutus-tx-plugin/test/Recursion/Spec.hs
Original file line number Diff line number Diff line change
@@ -1,70 +1,108 @@
{-# LANGUAGE BangPatterns #-}
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DataKinds #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE NegativeLiterals #-}
{-# LANGUAGE NoImplicitPrelude #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE PatternSynonyms #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE ViewPatterns #-}

{-# OPTIONS_GHC -fplugin PlutusTx.Plugin #-}
{-# OPTIONS_GHC -fplugin-opt PlutusTx.Plugin:defer-errors #-}

module Recursion.Spec where

import PlutusTx.Prelude
import Test.Tasty.Extras

import PlutusTx.Code
import PlutusTx.Function (fix)
import PlutusTx.Lift (liftCodeDef)
import PlutusTx.Optimize.SpaceTime (peel)
import PlutusTx.Prelude qualified as PlutusTx
import PlutusTx.Optimize.SpaceTime (peel, unroll)
import PlutusTx.Test
import PlutusTx.TH (compile)

tests :: TestNested
tests =
testNested "Recursion" . pure $
testNestedGhc
[ goldenUPlcReadable "length-direct" compiledLengthDirect
testNested "Recursion"
. pure
$ testNestedGhc
[ -- length function implemented using direct recursion
goldenUPlcReadable
"length-direct"
compiledLengthDirect
, goldenEvalCekCatch
"length-direct"
[compiledLengthDirect `unsafeApplyCode` liftCodeDef [1..10]]
[compiledLengthDirect `unsafeApplyCode` liftCodeDef [1 .. 10]]
, goldenBudget
"length-direct"
(compiledLengthDirect `unsafeApplyCode` liftCodeDef [1..10])
, goldenUPlcReadable "length-fix" compiledLengthFix
(compiledLengthDirect `unsafeApplyCode` liftCodeDef [1 .. 10])
, -- length function implemented using fix
goldenUPlcReadable
"length-fix"
compiledLengthFix
, goldenEvalCekCatch
"length-fix"
[compiledLengthFix `unsafeApplyCode` liftCodeDef [1..10]]
[compiledLengthFix `unsafeApplyCode` liftCodeDef [1 .. 10]]
, goldenBudget
"length-fix"
(compiledLengthFix `unsafeApplyCode` liftCodeDef [1..10])
, goldenUPlcReadable "length-peeled" compiledLengthPeeled
(compiledLengthFix `unsafeApplyCode` liftCodeDef [1 .. 10])
, -- length function implemented using fix
-- with 3 steps "peeled off" before recursing
goldenUPlcReadable
"length-peeled"
compiledLengthPeeled
, goldenEvalCekCatch
"length-peeled"
[compiledLengthPeeled `unsafeApplyCode` liftCodeDef [1..10]]
[compiledLengthPeeled `unsafeApplyCode` liftCodeDef [1 .. 10]]
, goldenBudget
"length-peeled"
(compiledLengthPeeled `unsafeApplyCode` liftCodeDef [1..10])
(compiledLengthPeeled `unsafeApplyCode` liftCodeDef [1 .. 10])
, -- length function implemented using fix
-- with 3 steps "unrolled" per each recursive call
goldenUPlcReadable
"length-unrolled"
compiledLengthUnrolled
, goldenEvalCekCatch
"length-unrolled"
[compiledLengthUnrolled `unsafeApplyCode` liftCodeDef [1 .. 10]]
, goldenBudget
"length-unrolled"
(compiledLengthUnrolled `unsafeApplyCode` liftCodeDef [1 .. 10])
]

lengthDirect :: [Integer] -> Integer
lengthDirect = \case
[] -> 0
(_ : xs) -> 1 PlutusTx.+ lengthDirect xs
[] -> 0
_ : xs -> 1 + lengthDirect xs

lengthFix :: [Integer] -> Integer
lengthFix =
fix
( \f -> \case
[] -> 0
(_ : xs) -> 1 PlutusTx.+ f xs
)
fix \self -> \case
[] -> 0
_ : xs -> 1 + self xs

lengthPeeled :: [Integer] -> Integer
lengthPeeled = $$(peel 3 (\f -> [|| \case [] -> 0; (_ : xs) -> 1 PlutusTx.+ $$f xs ||]))
lengthPeeled =
$$( peel 3 \self ->
[||
\case
[] -> 0
_ : xs -> 1 + $$self xs
||]
)

lengthUnrolled :: [Integer] -> Integer
lengthUnrolled =
$$( unroll 3 \self ->
[||
\case
[] -> 0
_ : xs -> 1 + $$self xs
||]
)

compiledLengthDirect :: CompiledCode ([Integer] -> Integer)
compiledLengthDirect = $$(compile [||lengthDirect||])
Expand All @@ -74,3 +112,6 @@ compiledLengthFix = $$(compile [||lengthFix||])

compiledLengthPeeled :: CompiledCode ([Integer] -> Integer)
compiledLengthPeeled = $$(compile [||lengthPeeled||])

compiledLengthUnrolled :: CompiledCode ([Integer] -> Integer)
compiledLengthUnrolled = $$(compile [||lengthUnrolled||])
83 changes: 67 additions & 16 deletions plutus-tx/src/PlutusTx/Optimize/SpaceTime.hs
Original file line number Diff line number Diff line change
@@ -1,37 +1,88 @@
{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE TemplateHaskell #-}

-- | Utilities for space-time tradeoff, such as recursion unrolling.
module PlutusTx.Optimize.SpaceTime (peel) where
module PlutusTx.Optimize.SpaceTime (peel, unroll) where

import Prelude

import Language.Haskell.TH.Syntax.Compat qualified as TH
import PlutusTx.Function (fix)
import Prelude

{-| Given @n@, and the functional (or algebra) for a recursive function, peel @n@ layers
{-| Given @n@, and the step function for a recursive function, peel @n@ layers
off of the recursion.
For example @peel 2 (\f xs -> case xs of [] -> 0; (_:ys) -> 1 + f ys)@ yields the
equivalence of the following function:
For example @peel 3 (\self -> [[| \case [] -> 0; _ : ys -> 1 + self ys||])@
yields the equivalence of the following function:
@
lengthPeeled :: [a] -> a
lengthPeeled xs = case xs of
[] -> 0
y:ys -> 1 + case ys of
[] -> 0
z:zs -> 1 + case zs of
[] -> 0
w:ws -> 1 + length ws
lengthPeeled :: [a] -> a
lengthPeeled xs =
case xs of -- first recursion step
[] -> 0
_ : ys -> 1 +
case ys of -- second recursion step
[] -> 0
_ : zs -> 1 +
case zs of -- third recursion step
[] -> 0
_ : ws -> 1 +
( fix \self qs -> -- rest of recursion steps in a tight loop
case qs of
[] -> 0
_ : ts -> 1 + self ts
) ws
@
where @length@ is the regular recursive definition.
-}
peel
:: forall a b
. Int
-- ^ How many recursion steps to move outside of the recursion loop.
-> (TH.SpliceQ (a -> b) -> TH.SpliceQ (a -> b))
{- ^ Function that given a continuation splice returns
a splice representing a single recursion step calling this continuation.
-}
-> TH.SpliceQ (a -> b)
peel 0 f = [||fix (\g -> $$(f [||g||]))||]
peel 0 f = [||fix \self -> $$(f [||self||])||]
peel n f
| n > 0 = f (peel (n - 1) f)
| otherwise = error $ "PlutusTx.Optimize.SpaceTime.peel: negative n: " <> show n

{-| Given @n@, and the step function for a recursive function,
unroll recursion @n@ layers at a time
For example @unroll 3 (\self -> [|| \case [] -> 0; _ : ys -> 1 + self ys ||])@
yields the equivalence of the following function:
@
lengthUnrolled :: [a] -> a
lengthUnrolled =
fix \self xs -> -- beginning of the recursion "loop"
case xs of -- first recursion step
[] -> 0
_ : ys -> 1 +
case ys of -- second recursion step
[] -> 0
_ : zs -> 1 +
case zs of -- third recursion step
[] -> 0
_ : ws -> 1 + self ws -- end of the "loop"
@
-}
unroll
:: forall a b
. Int
-- ^ How many recursion steps to perform inside the recursion loop.
-> (TH.SpliceQ (a -> b) -> TH.SpliceQ (a -> b))
{- ^ Function that given a continuation splice returns
a splice representing a single recursion step calling this continuation.
-}
-> TH.SpliceQ (a -> b)
unroll n f = [||fix \self -> $$(nTimes n f [||self||])||]

-- | Apply a function @n@ times to a given value.
nTimes :: Int -> (a -> a) -> (a -> a)
nTimes 0 _ = id
nTimes 1 f = f
nTimes n f = f . nTimes (n - 1) f

0 comments on commit c9e0eca

Please sign in to comment.