Skip to content

Commit

Permalink
Split blas Ops and rewrites
Browse files Browse the repository at this point in the history
Having Ops and rewrites in the same files was causing circular imports.
  • Loading branch information
dehorsley authored and ricardoV94 committed May 25, 2023
1 parent 86cbde5 commit c655b02
Show file tree
Hide file tree
Showing 8 changed files with 1,025 additions and 671 deletions.
560 changes: 6 additions & 554 deletions pytensor/tensor/blas.py

Large diffs are not rendered by default.

72 changes: 0 additions & 72 deletions pytensor/tensor/blas_c.py
Original file line number Diff line number Diff line change
@@ -1,22 +1,12 @@
from pytensor.configdefaults import config
from pytensor.graph.rewriting.basic import in2out
from pytensor.link.c.op import COp
from pytensor.link.c.params_type import ParamsType
from pytensor.scalar import bool as bool_t
from pytensor.tensor import basic as at
from pytensor.tensor.blas import (
Gemv,
Ger,
blas_header_text,
blas_header_version,
blas_optdb,
gemv_inplace,
gemv_no_inplace,
ger,
ger_destructive,
ldflags,
node_rewriter,
optdb,
)


Expand Down Expand Up @@ -344,23 +334,6 @@ def c_code_cache_version(self):
cger_no_inplace = CGer(False)


@node_rewriter([ger, ger_destructive])
def use_c_ger(fgraph, node):
if not config.blas__ldflags:
return
# Only float32 and float64 are supported for now.
if node.op == ger and node.outputs[0].dtype in ("float32", "float64"):
return [CGer(False)(*node.inputs)]
if node.op == ger_destructive and node.outputs[0].dtype in ("float32", "float64"):
return [CGer(True)(*node.inputs)]


@node_rewriter([CGer(False)])
def make_c_ger_destructive(fgraph, node):
if isinstance(node.op, CGer) and not node.op.destructive:
return [cger_inplace(*node.inputs)]


# ##### ####### #######
# GEMV
# ##### ####### #######
Expand Down Expand Up @@ -697,48 +670,3 @@ def check_force_gemv_init():


check_force_gemv_init._force_init_beta = None


@node_rewriter([gemv_inplace, gemv_no_inplace])
def use_c_gemv(fgraph, node):
if not config.blas__ldflags:
return
# Only float32 and float64 are supported for now.
if node.op == gemv_no_inplace and node.outputs[0].dtype in ("float32", "float64"):
return [cgemv_no_inplace(*node.inputs)]
if node.op == gemv_inplace and node.outputs[0].dtype in ("float32", "float64"):
return [cgemv_inplace(*node.inputs)]


@node_rewriter([CGemv(inplace=False)])
def make_c_gemv_destructive(fgraph, node):
if isinstance(node.op, CGemv) and not node.op.inplace:
inputs = list(node.inputs)
dest = inputs[0]
if (
dest.owner
and isinstance(dest.owner.op, at.AllocEmpty)
and len(fgraph.clients[dest]) > 1
):
inputs[0] = at.AllocEmpty(dest.dtype)(*dest.owner.inputs)

return [cgemv_inplace(*inputs)]


# ##### ####### #######
# Optimizers
# ##### ####### #######

blas_optdb.register(
"use_c_blas", in2out(use_c_ger, use_c_gemv), "fast_run", "c_blas", position=20
)

# this matches the InplaceBlasOpt defined in blas.py
optdb.register(
"c_blas_destructive",
in2out(make_c_ger_destructive, make_c_gemv_destructive, name="c_blas_destructive"),
"fast_run",
"inplace",
"c_blas",
position=70.0,
)
44 changes: 1 addition & 43 deletions pytensor/tensor/blas_scipy.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,16 +4,7 @@

import numpy as np

from pytensor.graph.rewriting.basic import in2out
from pytensor.tensor.blas import (
Ger,
blas_optdb,
ger,
ger_destructive,
have_fblas,
node_rewriter,
optdb,
)
from pytensor.tensor.blas import Ger, have_fblas


if have_fblas:
Expand Down Expand Up @@ -56,36 +47,3 @@ def perform(self, node, inputs, output_storage):

scipy_ger_no_inplace = ScipyGer(False)
scipy_ger_inplace = ScipyGer(True)


@node_rewriter([ger, ger_destructive])
def use_scipy_ger(fgraph, node):
if node.op == ger:
return [scipy_ger_no_inplace(*node.inputs)]


@node_rewriter([scipy_ger_no_inplace])
def make_ger_destructive(fgraph, node):
if node.op == scipy_ger_no_inplace:
return [scipy_ger_inplace(*node.inputs)]


use_scipy_blas = in2out(use_scipy_ger)
make_scipy_blas_destructive = in2out(make_ger_destructive)

if have_fblas:
# scipy_blas is scheduled in the blas_optdb very late, because scipy sortof
# sucks, but it is almost always present.
# C implementations should be scheduled earlier than this, so that they take
# precedence. Once the original Ger is replaced, then these optimizations
# have no effect.
blas_optdb.register("scipy_blas", use_scipy_blas, "fast_run", position=100)

# this matches the InplaceBlasOpt defined in blas.py
optdb.register(
"make_scipy_blas_destructive",
make_scipy_blas_destructive,
"fast_run",
"inplace",
position=70.0,
)
3 changes: 3 additions & 0 deletions pytensor/tensor/rewriting/__init__.py
Original file line number Diff line number Diff line change
@@ -1,4 +1,7 @@
import pytensor.tensor.rewriting.basic
import pytensor.tensor.rewriting.blas
import pytensor.tensor.rewriting.blas_c
import pytensor.tensor.rewriting.blas_scipy
import pytensor.tensor.rewriting.elemwise
import pytensor.tensor.rewriting.extra_ops

Expand Down
Loading

0 comments on commit c655b02

Please sign in to comment.