Skip to content

Commit

Permalink
Remove duplicated BLAS rewriting code
Browse files Browse the repository at this point in the history
Accidentally introduced in c655b02

Also move tests to the rewriting test file
  • Loading branch information
ricardoV94 committed Jan 23, 2025
1 parent a0fe30d commit c1e4bb0
Show file tree
Hide file tree
Showing 3 changed files with 165 additions and 466 deletions.
321 changes: 1 addition & 320 deletions pytensor/tensor/blas.py
Original file line number Diff line number Diff line change
Expand Up @@ -79,7 +79,6 @@
import logging
import os
import shlex
import time
from pathlib import Path

import numpy as np
Expand All @@ -103,10 +102,8 @@
from pytensor.tensor import basic as ptb
from pytensor.tensor.basic import expand_dims
from pytensor.tensor.blas_headers import blas_header_text, blas_header_version
from pytensor.tensor.elemwise import DimShuffle
from pytensor.tensor.math import add, mul, neg, sub, variadic_add
from pytensor.tensor.shape import shape_padright, specify_broadcastable
from pytensor.tensor.type import DenseTensorType, TensorType, integer_dtypes, tensor
from pytensor.tensor.type import DenseTensorType, tensor


_logger = logging.getLogger("pytensor.tensor.blas")
Expand Down Expand Up @@ -1148,322 +1145,6 @@ def c_code_cache_version(self):
pprint.assign(gemm_no_inplace, FunctionPrinter(["gemm_no_inplace"]))


def res_is_a(fgraph, var, op, maxclients=None):
    """Tell whether `var` is the output of `op`, optionally bounding its clients.

    Parameters
    ----------
    fgraph
        Graph object providing ``clients`` and ``get_clients``.
    var
        Variable whose ``owner`` is inspected.
    op
        Op that ``var.owner.op`` is compared against with ``==``.
    maxclients
        When given, and `var` is registered in ``fgraph.clients``, `var` must
        have at most this many clients.

    Returns
    -------
    A truthy value when `var` has an owner applying `op` and the client bound
    (if any) holds; otherwise a falsy value (``None``/``False``).
    """
    within_limit = True
    if maxclients is not None and var in fgraph.clients:
        within_limit = len(fgraph.get_clients(var)) <= maxclients

    # Chained ``and`` on purpose: the first falsy operand is what gets returned.
    return var.owner and var.owner.op == op and within_limit


def _as_scalar(res, dtype=None):
"""Return ``None`` or a `TensorVariable` of float type"""
if dtype is None:
dtype = config.floatX
if all(s == 1 for s in res.type.shape):
while res.owner and isinstance(res.owner.op, DimShuffle):
res = res.owner.inputs[0]
# may still have some number of True's
if res.type.ndim > 0:
rval = res.dimshuffle()
else:
rval = res
if rval.type.dtype in integer_dtypes:
# We check that the upcast of res and dtype won't change dtype.
# If dtype is float64, we will cast int64 to float64.
# This is valid when res is a scalar used as input to a dot22
# as the cast of the scalar can be done before or after the dot22
# and this will give the same result.
if pytensor.scalar.upcast(res.dtype, dtype) == dtype:
return ptb.cast(rval, dtype)
else:
return None

return rval


def _is_real_matrix(res):
return (
res.type.dtype in ("float16", "float32", "float64")
and res.type.ndim == 2
and res.type.shape[0] != 1
and res.type.shape[1] != 1
) # cope with tuple vs. list


def _is_real_vector(res):
return (
res.type.dtype in ("float16", "float32", "float64")
and res.type.ndim == 1
and res.type.shape[0] != 1
)


def _beta_L_plus_alpha_M(fgraph, beta, L, alpha, M, recurse_flip=True):
# print 'BETA L + ALPHA M', beta, L, alpha, M, recurse_flip
# EXPRESSION: (beta * L) + (alpha * M)

# we've already checked the client counts, now just make the type check.
# if res_is_a(M, _dot22, 1):
if M.owner and M.owner.op == _dot22:
Ml, Mr = M.owner.inputs
rval = [gemm_no_inplace(L, alpha, Ml, Mr, beta)]
return rval, M

# it also might be the case that there is a dimshuffle between the +
# and the dot22. local_dot_to_dot22 in particular will put in such things.
if (
M.owner
and isinstance(M.owner.op, DimShuffle)
and M.owner.inputs[0].owner
and isinstance(M.owner.inputs[0].owner.op, Dot22)
):
MM = M.owner.inputs[0]
if M.owner.op.new_order == (0,):
# it is making a column MM into a vector
MMl, MMr = MM.owner.inputs
g = gemm_no_inplace(L.dimshuffle(0, "x"), alpha, MMl, MMr, beta)
rval = [g.dimshuffle(0)]
return rval, MM
if M.owner.op.new_order == (1,):
# it is making a row MM into a vector
MMl, MMr = MM.owner.inputs
g = gemm_no_inplace(L.dimshuffle("x", 0), alpha, MMl, MMr, beta)
rval = [g.dimshuffle(1)]
return rval, MM
if len(M.owner.op.new_order) == 0:
# it is making a row MM into a vector
MMl, MMr = MM.owner.inputs
g = gemm_no_inplace(L.dimshuffle("x", "x"), alpha, MMl, MMr, beta)
rval = [g.dimshuffle()]
return rval, MM

if recurse_flip:
return _beta_L_plus_alpha_M(fgraph, alpha, M, beta, L, recurse_flip=False)
else:
return False, False


def _gemm_canonicalize(fgraph, r, scale, rval, maxclients):
    """Decompose `r` into scaled terms, appending them to `rval`.

    Tries to interpret `r` as a sum of ``scalar * (vector or matrix)`` terms by
    recursing through ``add``, ``sub``, ``neg`` and ``mul`` nodes.

    Parameters
    ----------
    fgraph
        Graph whose ``clients`` mapping is consulted for the client bound.
    r
        Variable to decompose.
    scale
        Scale accumulated so far for `r`.
    rval
        List that receives the results, mutated in place. Entries are
        ``(scale, variable)`` tuples, except for variables of unsupported
        rank or dtype, which are appended already multiplied by `scale`.
    maxclients
        When truthy, `r` is not opened up if it has more clients than this.

    Returns
    -------
    `rval`, or ``None`` when ``r.type`` is not a `TensorType` (in which case
    nothing is appended).
    """

    def scaled(thing):
        if scale == 1:
            return thing
        if scale == -1 and thing.type.dtype != "bool":
            return -thing
        return scale * thing

    if not isinstance(r.type, TensorType):
        return None

    supported_dtypes = ("float16", "float32", "float64", "complex64", "complex128")
    if r.type.ndim not in (1, 2) or r.type.dtype not in supported_dtypes:
        rval.append(scaled(r))
        return rval

    if maxclients and len(fgraph.clients[r]) > maxclients:
        rval.append((scale, r))
        return rval

    node = r.owner
    if not node:
        # Nothing produced `r`, so there is nothing to open up.
        rval.append((scale, r))
        return rval

    if node.op == sub:
        minuend, subtrahend = node.inputs[0], node.inputs[1]
        _gemm_canonicalize(fgraph, minuend, scale, rval, 1)
        _gemm_canonicalize(fgraph, subtrahend, -scale, rval, 1)

    elif node.op == add:
        for term in node.inputs:
            _gemm_canonicalize(fgraph, term, scale, rval, 1)

    elif node.op == neg:
        _gemm_canonicalize(fgraph, node.inputs[0], -scale, rval, 1)

    elif node.op == mul:
        scalars, vectors, matrices = [], [], []
        for factor in node.inputs:
            if all(dim == 1 for dim in factor.type.shape):
                # Look through DimShuffles and reduce the factor to 0-d.
                while factor.owner and isinstance(factor.owner.op, DimShuffle):
                    factor = factor.owner.inputs[0]
                scalars.append(
                    factor.dimshuffle() if factor.type.ndim > 0 else factor
                )
            elif _is_real_vector(factor):
                vectors.append(factor)
            elif _is_real_matrix(factor):
                matrices.append(factor)
            else:
                # Unclassifiable factor: keep the product as an opaque term.
                rval.append((scale, r))
                return rval

        if len(matrices) == 1:
            assert len(vectors) == 0
            operand = matrices[0]
        elif len(vectors) == 1:
            assert len(matrices) == 0
            operand = vectors[0]
        else:
            operand = None

        if operand is None:
            # Zero or several non-scalar factors: do not open this up.
            rval.append((scale, r))
        else:
            # Fold the scalar factors into the scale passed down.
            if len(scalars) == 0:
                inner_scale = scale
            elif len(scalars) == 1:
                inner_scale = scaled(scalars[0])
            else:
                inner_scale = mul(scaled(scalars[0]), *scalars[1:])
            _gemm_canonicalize(fgraph, operand, inner_scale, rval, 1)

    else:
        rval.append((scale, r))

    return rval


def _factor_canonicalized(lst):
# remove duplicates from canonicalized list

# we only delete out of the right end of the list,
# once i has touched a list element, it is permantent
lst = list(lst)
# print 'FACTOR', lst
# for t in lst:
# if not isinstance(t, (list, tuple)):
# t = (t,)
# for e in t:
# try:
# pytensor.printing.debugprint(e)
# except TypeError:
# print e, type(e)
i = 0
while i < len(lst) - 1:
try:
s_i, M_i = lst[i]
except Exception:
i += 1
continue

j = i + 1
while j < len(lst):
try:
s_j, M_j = lst[j]
except Exception:
j += 1
continue

if M_i is M_j:
s_i = s_i + s_j
lst[i] = (s_i, M_i)
del lst[j]
else:
j += 1
i += 1
return lst


def _gemm_from_factored_list(fgraph, lst):
"""
Returns None, or a list to replace node.outputs.
"""
lst2 = []
# Remove the tuple that can't be cast correctly.
# This can happen when we try to cast a complex to a real
for sM in lst:
# Make every pair in list have matching dtypes
# sM can be a tuple of 2 elements or an PyTensor variable.
if isinstance(sM, tuple):
sm0, sm1 = sM
sm0 = ptb.as_tensor_variable(sm0)
if pytensor.scalar.upcast(sm0.dtype, sm1.dtype) == sm1.dtype:
lst2.append((ptb.cast(sm0, sm1.dtype), sM[1]))

lst = lst2

def item_to_var(t):
try:
s, M = t
except Exception:
return t
if s == 1:
return M
if s == -1:
return -M
return s * M

# Try every pair in the sM_list, trying to turn it into a gemm operation
for i in range(len(lst) - 1):
s_i, M_i = lst[i]

for j in range(i + 1, len(lst)):
s_j, M_j = lst[j]

if not M_j.type.in_same_class(M_i.type):
continue

# print 'TRYING', (s_i, M_i, s_j, M_j)

gemm_of_sM_list, old_dot22 = _beta_L_plus_alpha_M(
fgraph, s_i, M_i, s_j, M_j
)
# print 'GOT IT', gemm_of_sM_list
if gemm_of_sM_list:
assert len(gemm_of_sM_list) == 1
add_inputs = [
item_to_var(input) for k, input in enumerate(lst) if k not in (i, j)
]
add_inputs.extend(gemm_of_sM_list)
rval = [variadic_add(*add_inputs)]
return rval, old_dot22


def _gemm_from_node2(fgraph, node):
    """Try to replace the first output of `node` with a GEMM-based expression.

    TODO: In many expressions, there are many ways to turn it into a
    gemm.  For example dot(a,b) + c + d.  This function should return all
    of them, so that if one version of gemm causes a cycle in the graph, then
    another application of gemm can be tried.

    Returns
    -------
    tuple
        ``(rval, canonicalize_time, factor_time, gemm_time)``. ``rval`` is the
        result of `_gemm_from_factored_list`, or ``None`` when no acceptable
        rewrite was found, in which case the last two timings are ``0``.
    """
    terms = []
    started = time.perf_counter()
    _gemm_canonicalize(fgraph, node.outputs[0], 1.0, terms, 0)
    canonicalized = time.perf_counter()
    canonicalize_time = canonicalized - started

    if len(terms) <= 1:
        return None, canonicalize_time, 0, 0

    terms = _factor_canonicalized(terms)
    factored = time.perf_counter()
    rval = _gemm_from_factored_list(fgraph, terms)
    built = time.perf_counter()

    # _factor_canonicalized and _gemm_from_factored_list can produce a
    # variable of the wrong type, in particular when one of the scalar
    # factors forces an upcast of the whole expression. Such a candidate is
    # simply skipped. This was discussed in
    # http://groups.google.com/group/theano-dev/browse_thread/thread/a3096c82856e3ad5,
    # but never made it into a trac ticket.
    if rval and rval[0][0].type.in_same_class(node.outputs[0].type):
        return rval, canonicalize_time, factored - canonicalized, built - factored

    return None, canonicalize_time, 0, 0


class Dot22(GemmRelated):
"""Compute a matrix-matrix product.
Expand Down
Loading

0 comments on commit c1e4bb0

Please sign in to comment.