Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Beer -> Food cookbook example #276

Merged
merged 2 commits into from
Feb 19, 2025
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
2 changes: 1 addition & 1 deletion docs/blog/posts/2024-nov-11-mission.md
Original file line number Diff line number Diff line change
Expand Up @@ -103,7 +103,7 @@ JAX is a great project for automatic differentiation and scientific computing. I

While Kirin draws inspiration from MLIR, xDSL, Julia compiler plugin, and JAXPR, we aim to build a more user-friendly compiler infrastructure for scientists to solve their specific problems. There are nothing fundamentally new in theory, but the combination is new. Here are some key differences:

**Composable Python Lowering**, in our beer-lang example, the kernel decorator `@beer` is just
**Composable Python Lowering**, in our food-lang example, the kernel decorator `@food` is just
a `DialectGroup` object that contains the `Dialect` objects you specified to include for the frontend. The Python syntax just maginally works! This is because Kirin features a composable lowering system that allows you to claim Python syntax from each separate dialect. When combining the dialects together, Kirin will be able to compile each Python syntax to the corresponding IR nodes, e.g

- `func` dialect claims function related syntax: the `ast.FunctionDef`, nested `ast.FunctionDef` (as closures), `ast.Call`, `ast.Return`, etc.
Expand Down
160 changes: 95 additions & 65 deletions docs/cookbook/beer_dialect/analysis.md
Original file line number Diff line number Diff line change
@@ -1,29 +1,29 @@
## Beer price/fee analysis
## Food price/fee analysis

In this section we will discuss on how to perform analysis of a kirin program. We will again use our `beer` dialect example.
In this section we will discuss on how to perform analysis of a kirin program. We will again use our `food` dialect example.

### Goal

Let's Consider the following program
```python
@beer
@food
def main2(x: int):

bud = NewBeer(brand="budlight")
heineken = NewBeer(brand="heineken")
burger = NewFood(type="burger")
salad = NewFood(type="salad")

bud_pints = Pour(bud, 12 + x)
heineken_pints = Pour(heineken, 10 + x)
burger_serving = Cook(burger, 12 + x)
salad_serving = Cook(salad, 10 + x)

Drink(bud_pints)
Drink(heineken_pints)
Puke()
Eat(burger_serving)
Eat(salad_serving)
Nap()

Drink(bud_pints)
Puke()
Eat(burger_serving)
Nap()

Drink(bud_pints)
Puke()
Eat(burger_serving)
Nap()

return x
```
Expand Down Expand Up @@ -103,17 +103,18 @@ Next there are a few more lattice elements we want to define:
```python
@final
@dataclass
class ItemPints(Item): # (1)!
class ItemServing(Item): # (1)!
count: Item
brand: str
type: str

def is_subseteq(self, other: Item) -> bool:
return (
isinstance(other, ItemPints)
isinstance(other, ItemServing)
and self.count == other.count
and self.brand == other.brand
and self.type == other.type
)


@final
@dataclass
class AtLeastXItem(Item): # (2)!
Expand All @@ -125,7 +126,7 @@ class AtLeastXItem(Item): # (2)!

@final
@dataclass
class ConstIntItem(Item):
class ConstIntItem(Item): # (3)!
data: int

def is_subseteq(self, other: Item) -> bool:
Expand All @@ -134,19 +135,17 @@ class ConstIntItem(Item):

@final
@dataclass
class ItemBeer(Item):
brand: str
class ItemFood(Item): # (4)!
type: str

def is_subseteq(self, other: Item) -> bool:
return isinstance(other, ItemBeer) and self.brand == other.brand


return isinstance(other, ItemFood) and self.type == other.type
```

1. `ItemPints` which contain information of the beer brand of `Pints`, as well as the count
1. `ItemServing` which contain information of the kind of food of the `Serving`, as well as the count
2. `AtLeastXItem` which contain information of a constant type result value is a number that is least `x`. The `data` contain the lower-bound
3. `ConstIntItem` which contain concrete number.
4. `ItemBeer` which contain information of `Beer`.
4. `ItemFood` which contains information of `Food`.


### Custom Forward Data Flow Analysis
Expand All @@ -158,14 +157,14 @@ In kirin, the analysis pass is implemented with `AbstractInterpreter` (inspired
Here our analysis want to do two things.

1. Get all the analysis results as dictionary of SSAVAlue to lattice element.
2. Count how many time one puke.
2. Count how many time one naps.

```python
@dataclass
class FeeAnalysis(Forward[latt.Item]): # (1)!
keys = ["beer.fee"] # (2)!
keys = ["food.fee"] # (2)!
lattice = latt.Item
puke_count: int = field(init=False)
nap_count: int = field(init=False)

def initialize(self): # (3)!
"""Initialize the analysis pass.
Expand All @@ -176,11 +175,11 @@ class FeeAnalysis(Forward[latt.Item]): # (1)!
1. Here one is *required* to call the super().initialize() to initialize the analysis pass,
which clear all the previous analysis results and symbol tables.
2. Any additional initialization that belongs to the analysis should also be done here.
For example, in this case, we initialize the puke_count to 0.
For example, in this case, we initialize the nap_count to 0.

"""
super().initialize()
self.puke_count = 0
self.nap_count = 0
return self

def eval_stmt_fallback( # (4)!
Expand All @@ -191,6 +190,37 @@ class FeeAnalysis(Forward[latt.Item]): # (1)!
def run_method(self, method: ir.Method, args: tuple[latt.Item, ...]) -> latt.Item: # (5)!
return self.run_callable(method.code, (self.lattice.bottom(),) + args)

@dataclass
class FeeAnalysis(Forward[Item]): # (1)!
keys = ["food.fee"] # (2)!
lattice = Item
nap_count: int = field(init=False)

def initialize(self): # (3)!
"""Initialize the analysis pass.

The method is called before the analysis pass starts.

Note:
1. Here one is *required* to call the super().initialize() to initialize the analysis pass,
which clear all the previous analysis results and symbol tables.
2. Any additional initialization that belongs to the analysis should also be done here.
For example, in this case, we initialize the nap_count to 0.

"""
super().initialize()
self.nap_count = 0
return self

def eval_stmt_fallback( # (4)!
self, frame: ForwardFrame[Item], stmt: ir.Statement
) -> tuple[Item, ...] | interp.SpecialValue[Item]:
return ()

def run_method(self, method: ir.Method, args: tuple[Item, ...]): # (5)!
return self.run_callable(method.code, (self.lattice.bottom(),) + args)


```

1. Interit from `Forward` with our customize lattice `Item`.
Expand All @@ -205,25 +235,25 @@ Now we want to implement how the statement gets run. This is the same as what we

Note that each dialect can have multiple registered `MethodTable`, distinguished by `key`. The interpreter use `key` to find corresponding `MethodTable`s for each dialect in a dialect group.

Here, we use `key="beer.fee"`
Here, we use `key="food.fee"`

First we need to implement for `Constant` statement in `py.constant` dialect. If its `int`, we return `ConstIntItem` lattice element. If its `Beer`, we return `ItemBeer`.
First we need to implement for `Constant` statement in `py.constant` dialect. If its `int`, we return `ConstIntItem` lattice element. If its `Food`, we return `ItemFood`.

```python
@py.constant.dialect.register(key="beer.fee")
@py.constant.dialect.register(key="food.fee")
class PyConstMethodTable(interp.MethodTable):

@interp.impl(py.constant.Constant)
def const(
self,
interp: FeeAnalysis,
frame: interp.Frame[latt.Item],
frame: interp.Frame[Item],
stmt: py.constant.Constant,
):
if isinstance(stmt.value, int):
return (latt.ConstIntItem(data=stmt.value),)
elif isinstance(stmt.value, Beer):
return (latt.ItemBeer(brand=stmt.value.brand),)
return (ConstIntItem(data=stmt.value),)
elif isinstance(stmt.value, Food):
return (ItemFood(type=stmt.value.type),)

else:
raise exceptions.InterpreterError(
Expand All @@ -234,65 +264,65 @@ class PyConstMethodTable(interp.MethodTable):

Next, since we allow `add` in the program, we also need to let abstract interpreter know how to interprete `binop.Add` statement, which is in `py.binop` dialect.
```python
@binop.dialect.register(key="beer.fee")
@binop.dialect.register(key="food.fee")
class PyBinOpMethodTable(interp.MethodTable):

@interp.impl(binop.Add)
def add(
self,
interp: FeeAnalysis,
frame: interp.Frame[latt.Item],
frame: interp.Frame[Item],
stmt: binop.Add,
):
left = frame.get(stmt.lhs)
right = frame.get(stmt.rhs)

if isinstance(left, latt.AtLeastXItem) or isinstance(right, latt.AtLeastXItem):
out = latt.AtLeastXItem(data=left.data + right.data)
if isinstance(left, AtLeastXItem) or isinstance(right, AtLeastXItem):
out = AtLeastXItem(data=left.data + right.data)
else:
out = latt.ConstIntItem(data=left.data + right.data)
out = ConstIntItem(data=left.data + right.data)

return (out,)
```

Finally, we need implementation for our beer dialect Statements.
Finally, we need implementation for our food dialect Statements.
```python
@dialect.register(key="beer.fee")
class BeerMethodTable(interp.MethodTable):
@dialect.register(key="food.fee")
class FoodMethodTable(interp.MethodTable):

@interp.impl(NewBeer)
def new_beer(
@interp.impl(NewFood)
def new_food(
self,
interp: FeeAnalysis,
frame: interp.Frame[latt.Item],
stmt: NewBeer,
frame: interp.Frame[Item],
stmt: NewFood,
):
return (latt.ItemBeer(brand=stmt.brand),)
return (ItemFood(type=stmt.type),)

@interp.impl(Pour)
def pour(
@interp.impl(Cook)
def cook(
self,
interp: FeeAnalysis,
frame: interp.Frame[latt.Item],
stmt: Pour,
frame: interp.Frame[Item],
stmt: Cook,
):
# Drink depends on the beer type to have different charge:
# food depends on the food type to have different charge:

beer: latt.ItemBeer = frame.get(stmt.beverage)
pint_count: latt.AtLeastXItem | latt.ConstIntItem = frame.get(stmt.amount)
food = frame.get_typed(stmt.target, ItemFood)
serving_count: AtLeastXItem | ConstIntItem = frame.get(stmt.amount)

out = latt.ItemPints(count=pint_count, brand=beer.brand)
out = ItemServing(count=serving_count, type=food.type)

return (out,)

@interp.impl(Puke)
def puke(
@interp.impl(Nap)
def nap(
self,
interp: FeeAnalysis,
frame: interp.Frame[latt.Item],
stmt: Puke,
frame: interp.Frame[Item],
stmt: Nap,
):
interp.puke_count += 1
interp.nap_count += 1
return ()

```
Expand All @@ -303,5 +333,5 @@ class BeerMethodTable(interp.MethodTable):
fee_analysis = FeeAnalysis(main2.dialects)
results, expect = fee_analysis.run_analysis(main2, args=(AtLeastXItem(data=10),))
print(results)
print(fee_analysis.puke_count)
print(fee_analysis.nap_count)
```
22 changes: 11 additions & 11 deletions docs/cookbook/beer_dialect/cf_rewrite.md
Original file line number Diff line number Diff line change
@@ -1,12 +1,12 @@
## Rewrite if-else control flow

In the main page, we introduce a simple `beer` dialect example, and described how to use kirin to define a simple compiler.
In this section, we want to continue with this `beer`, and considering more compilcated rewrite pass
that involving the build-in python dialect if-else control flow.
In the main page, we introduce a simple `food` dialect example, and described how to use kirin to define a simple compiler.
In this section, we want to continue with this exampe, and consider a more compilcated rewrite pass
that involves the build-in python dialect if-else control flow.

### Goal
When one get really really grunk, not only does one puke, we also would make random decision.
Here specifically, We want to rewrite the existing `IfElse` statement defined in the existing `py` dialect into a customize `RandomBranch` statement we defined in our beer dialect.
When one get really really full, not only does one nap, we also would make random decision.
Here specifically, We want to rewrite the existing `IfElse` statement defined in the existing `py` dialect into a customize `RandomBranch` statement we defined in our food dialect.

The execution of `RandomBranch`, as stated in its name, randomly execute a branch each time we run it.

Expand Down Expand Up @@ -44,14 +44,14 @@ unlike a normal `if else` branching statement, it does not execute the branches
it randomly chooses one of the branches to execute. We will implement the execution behavior of this statement in the following.

### Implementation and MethodTable
Recall in the introduction of beer dialect we metioned about `MethodTable`. Now we have defined the statement, we will need to tell interpreter how to interprete this Statement we defined.
Recall in the introduction of food dialect we metioned about `MethodTable`. Now we have defined the statement, we will need to tell interpreter how to interprete this Statement we defined.

Let's find the `BeerMethods` MethodTable that we defined and registered to `beer` dialect previously:
Let's find the `FoodMethods` MethodTable that we defined and registered to `food` dialect previously:
```python
from kirin.interp import Frame, Successor, Interpreter, MethodTable, impl

@dialect.register
class BeerMethods(MethodTable):
class FoodMethods(MethodTable):
...
```

Expand Down Expand Up @@ -120,13 +120,13 @@ from kirin.prelude import basic_no_opt
from kirin.rewrite import Walk, Fixpoint

@dialect_group(basic_no_opt.add(dialect))
def beer(self):
def food(self):

# some initialization if you need it
def run_pass(mt, drunk:bool=False, got_lost: bool=True): # (1)!
def run_pass(mt, hungry:bool=False, got_lost: bool=True): # (1)!

if drunk:
Walk(NewBeerAndPukeOnDrink()).rewrite(mt.code)
Walk(NewFoodAndNap()).rewrite(mt.code)

if got_lost:
Fixpoint(Walk(RandomWalkBranch())).rewrite(mt.code) # (2)!
Expand Down
Loading