Skip to content

Commit

Permalink
feat: improve the performance f multiply by adding matrix type infe…
Browse files Browse the repository at this point in the history
…rencing (#3149)

* added type inference

* added back accidentally removed return statement and made it so that the explicitly defined type is returned at the end

* made sure that mixed types are ignored in the process data types check

* fixed issue with undefined _data for SparseMatrix and linting issues

* simplified syntax and added type inferencing to src/type/matrix/utils and src/function/matrix/dot.js

* shortened the final part of the type inferencing and moved it to matrix creation in multiply

---------

Co-authored-by: Jos de Jong <[email protected]>
  • Loading branch information
RandomGamingDev and josdejong authored Feb 22, 2024
1 parent 9842ad9 commit 207623c
Show file tree
Hide file tree
Showing 12 changed files with 70 additions and 69 deletions.
1 change: 1 addition & 0 deletions AUTHORS
Original file line number Diff line number Diff line change
Expand Up @@ -239,6 +239,7 @@ Brooks Smith <[email protected]>
Alex Edgcomb <[email protected]>
S.Y. Lee <[email protected]>
Hudsxn <[email protected]>
RandomGamingDev <[email protected]>
Rich Martinez <[email protected]>

# Generated by tools/update-authors.js
60 changes: 30 additions & 30 deletions src/function/arithmetic/multiply.js
Original file line number Diff line number Diff line change
Expand Up @@ -110,11 +110,11 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
// a dense
const adata = a._data
const asize = a._size
const adt = a._datatype
const adt = a._datatype || a.getDataType()
// b dense
const bdata = b._data
const bsize = b._size
const bdt = b._datatype
const bdt = b._datatype || b.getDataType()
// rows & columns
const alength = asize[0]
const bcolumns = bsize[1]
Expand All @@ -127,7 +127,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
let mf = multiplyScalar

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
if (adt && bdt && adt === bdt && typeof adt === 'string' && adt !== 'mixed') {
// datatype
dt = adt
// find signatures that matches (dt, dt)
Expand All @@ -154,7 +154,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
return a.createDenseMatrix({
data: c,
size: [bcolumns],
datatype: dt
datatype: adt === a._datatype && bdt === b._datatype ? dt : undefined
})
}

Expand Down Expand Up @@ -198,10 +198,10 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
// a dense
const adata = a._data
const asize = a._size
const adt = a._datatype
const adt = a._datatype || a.getDataType()
// b dense
const bdata = b._data
const bdt = b._datatype
const bdt = b._datatype || b.getDataType()
// rows & columns
const arows = asize[0]
const acolumns = asize[1]
Expand All @@ -214,7 +214,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
let mf = multiplyScalar

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
if (adt && bdt && adt === bdt && typeof adt === 'string' && adt !== 'mixed') {
// datatype
dt = adt
// find signatures that matches (dt, dt)
Expand Down Expand Up @@ -243,7 +243,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
return a.createDenseMatrix({
data: c,
size: [arows],
datatype: dt
datatype: adt === a._datatype && bdt === b._datatype ? dt : undefined
})
}

Expand All @@ -255,15 +255,15 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
*
* @return {Matrix} DenseMatrix (MxC)
*/
function _multiplyDenseMatrixDenseMatrix (a, b) {
function _multiplyDenseMatrixDenseMatrix (a, b) { // getDataType()
// a dense
const adata = a._data
const asize = a._size
const adt = a._datatype
const adt = a._datatype || a.getDataType()
// b dense
const bdata = b._data
const bsize = b._size
const bdt = b._datatype
const bdt = b._datatype || b.getDataType()
// rows & columns
const arows = asize[0]
const acolumns = asize[1]
Expand All @@ -277,7 +277,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
let mf = multiplyScalar

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
if (adt && bdt && adt === bdt && typeof adt === 'string' && adt !== 'mixed' && adt !== 'mixed') {
// datatype
dt = adt
// find signatures that matches (dt, dt)
Expand Down Expand Up @@ -311,7 +311,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
return a.createDenseMatrix({
data: c,
size: [arows, bcolumns],
datatype: dt
datatype: adt === a._datatype && bdt === b._datatype ? dt : undefined
})
}

Expand All @@ -327,13 +327,13 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
// a dense
const adata = a._data
const asize = a._size
const adt = a._datatype
const adt = a._datatype || a.getDataType()
// b sparse
const bvalues = b._values
const bindex = b._index
const bptr = b._ptr
const bsize = b._size
const bdt = b._datatype
const bdt = b._datatype || b._data === undefined ? b._datatype : b.getDataType()
// validate b matrix
if (!bvalues) { throw new Error('Cannot multiply Dense Matrix times Pattern only Matrix') }
// rows & columns
Expand All @@ -352,7 +352,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
let zero = 0

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
if (adt && bdt && adt === bdt && typeof adt === 'string' && adt !== 'mixed') {
// datatype
dt = adt
// find signatures that matches (dt, dt)
Expand All @@ -373,7 +373,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
index: cindex,
ptr: cptr,
size: [arows, bcolumns],
datatype: dt
datatype: adt === a._datatype && bdt === b._datatype ? dt : undefined
})

// loop b columns
Expand Down Expand Up @@ -437,12 +437,12 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
const avalues = a._values
const aindex = a._index
const aptr = a._ptr
const adt = a._datatype
const adt = a._datatype || a._data === undefined ? a._datatype : a.getDataType()
// validate a matrix
if (!avalues) { throw new Error('Cannot multiply Pattern only Matrix times Dense Matrix') }
// b dense
const bdata = b._data
const bdt = b._datatype
const bdt = b._datatype || b.getDataType()
// rows & columns
const arows = a._size[0]
const brows = b._size[0]
Expand All @@ -463,7 +463,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
let zero = 0

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
if (adt && bdt && adt === bdt && typeof adt === 'string' && adt !== 'mixed') {
// datatype
dt = adt
// find signatures that matches (dt, dt)
Expand Down Expand Up @@ -516,13 +516,13 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
// update ptr
cptr[1] = cindex.length

// return sparse matrix
// matrix to return
return a.createSparseMatrix({
values: cvalues,
index: cindex,
ptr: cptr,
size: [arows, 1],
datatype: dt
datatype: adt === a._datatype && bdt === b._datatype ? dt : undefined
})
}

Expand All @@ -539,12 +539,12 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
const avalues = a._values
const aindex = a._index
const aptr = a._ptr
const adt = a._datatype
const adt = a._datatype || a._data === undefined ? a._datatype : a.getDataType()
// validate a matrix
if (!avalues) { throw new Error('Cannot multiply Pattern only Matrix times Dense Matrix') }
// b dense
const bdata = b._data
const bdt = b._datatype
const bdt = b._datatype || b.getDataType()
// rows & columns
const arows = a._size[0]
const brows = b._size[0]
Expand All @@ -562,7 +562,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
let zero = 0

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
if (adt && bdt && adt === bdt && typeof adt === 'string' && adt !== 'mixed') {
// datatype
dt = adt
// find signatures that matches (dt, dt)
Expand All @@ -583,7 +583,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
index: cindex,
ptr: cptr,
size: [arows, bcolumns],
datatype: dt
datatype: adt === a._datatype && bdt === b._datatype ? dt : undefined
})

// workspace
Expand Down Expand Up @@ -650,12 +650,12 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
const avalues = a._values
const aindex = a._index
const aptr = a._ptr
const adt = a._datatype
const adt = a._datatype || a._data === undefined ? a._datatype : a.getDataType()
// b sparse
const bvalues = b._values
const bindex = b._index
const bptr = b._ptr
const bdt = b._datatype
const bdt = b._datatype || b._data === undefined ? b._datatype : b.getDataType()

// rows & columns
const arows = a._size[0]
Expand All @@ -671,7 +671,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
let mf = multiplyScalar

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
if (adt && bdt && adt === bdt && typeof adt === 'string' && adt !== 'mixed') {
// datatype
dt = adt
// find signatures that matches (dt, dt)
Expand All @@ -689,7 +689,7 @@ export const createMultiply = /* #__PURE__ */ factory(name, dependencies, ({ typ
index: cindex,
ptr: cptr,
size: [arows, bcolumns],
datatype: dt
datatype: adt === a._datatype && bdt === b._datatype ? dt : undefined
})

// workspace
Expand Down
6 changes: 3 additions & 3 deletions src/function/matrix/dot.js
Original file line number Diff line number Diff line change
Expand Up @@ -64,10 +64,10 @@ export const createDot = /* #__PURE__ */ factory(name, dependencies, ({ typed, a
const N = _validateDim(a, b)

const adata = isMatrix(a) ? a._data : a
const adt = isMatrix(a) ? a._datatype : undefined
const adt = isMatrix(a) ? a._datatype || a.getDataType() : undefined

const bdata = isMatrix(b) ? b._data : b
const bdt = isMatrix(b) ? b._datatype : undefined
const bdt = isMatrix(b) ? b._datatype || b.getDataType() : undefined

// are these 2-dimensional column vectors? (as opposed to 1-dimensional vectors)
const aIsColumn = _size(a).length === 2
Expand All @@ -77,7 +77,7 @@ export const createDot = /* #__PURE__ */ factory(name, dependencies, ({ typed, a
let mul = multiplyScalar

// process data types
if (adt && bdt && adt === bdt && typeof adt === 'string') {
if (adt && bdt && adt === bdt && typeof adt === 'string' && adt !== 'mixed') {
const dt = adt
// find signatures that matches (dt, dt)
add = typed.find(addScalar, [dt, dt])
Expand Down
8 changes: 4 additions & 4 deletions src/type/matrix/utils/matAlgo01xDSid.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ export const createMatAlgo01xDSid = /* #__PURE__ */ factory(name, dependencies,
// dense matrix arrays
const adata = denseMatrix._data
const asize = denseMatrix._size
const adt = denseMatrix._datatype
const adt = denseMatrix._datatype || denseMatrix.getDataType()
// sparse matrix arrays
const bvalues = sparseMatrix._values
const bindex = sparseMatrix._index
const bptr = sparseMatrix._ptr
const bsize = sparseMatrix._size
const bdt = sparseMatrix._datatype
const bdt = sparseMatrix._datatype || sparseMatrix._data === undefined ? sparseMatrix._datatype : sparseMatrix.getDataType()

// validate dimensions
if (asize.length !== bsize.length) { throw new DimensionError(asize.length, bsize.length) }
Expand All @@ -50,7 +50,7 @@ export const createMatAlgo01xDSid = /* #__PURE__ */ factory(name, dependencies,
const columns = asize[1]

// process data types
const dt = typeof adt === 'string' && adt === bdt ? adt : undefined
const dt = typeof adt === 'string' && adt !== 'mixed' && adt === bdt ? adt : undefined
// callback function
const cf = dt ? typed.find(callback, [dt, dt]) : callback

Expand Down Expand Up @@ -97,7 +97,7 @@ export const createMatAlgo01xDSid = /* #__PURE__ */ factory(name, dependencies,
return denseMatrix.createDenseMatrix({
data: cdata,
size: [rows, columns],
datatype: dt
datatype: adt === denseMatrix._datatype && bdt === sparseMatrix._datatype ? dt : undefined
})
}
})
8 changes: 4 additions & 4 deletions src/type/matrix/utils/matAlgo02xDS0.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ export const createMatAlgo02xDS0 = /* #__PURE__ */ factory(name, dependencies, (
// dense matrix arrays
const adata = denseMatrix._data
const asize = denseMatrix._size
const adt = denseMatrix._datatype
const adt = denseMatrix._datatype || denseMatrix.getDataType()
// sparse matrix arrays
const bvalues = sparseMatrix._values
const bindex = sparseMatrix._index
const bptr = sparseMatrix._ptr
const bsize = sparseMatrix._size
const bdt = sparseMatrix._datatype
const bdt = sparseMatrix._datatype || sparseMatrix._data === undefined ? sparseMatrix._datatype : sparseMatrix.getDataType()

// validate dimensions
if (asize.length !== bsize.length) { throw new DimensionError(asize.length, bsize.length) }
Expand All @@ -59,7 +59,7 @@ export const createMatAlgo02xDS0 = /* #__PURE__ */ factory(name, dependencies, (
let cf = callback

// process data types
if (typeof adt === 'string' && adt === bdt) {
if (typeof adt === 'string' && adt === bdt && adt !== 'mixed') {
// datatype
dt = adt
// find signature that matches (dt, dt)
Expand Down Expand Up @@ -102,7 +102,7 @@ export const createMatAlgo02xDS0 = /* #__PURE__ */ factory(name, dependencies, (
index: cindex,
ptr: cptr,
size: [rows, columns],
datatype: dt
datatype: adt === denseMatrix._datatype && bdt === sparseMatrix._datatype ? dt : undefined
})
}
})
8 changes: 4 additions & 4 deletions src/type/matrix/utils/matAlgo03xDSf.js
Original file line number Diff line number Diff line change
Expand Up @@ -28,13 +28,13 @@ export const createMatAlgo03xDSf = /* #__PURE__ */ factory(name, dependencies, (
// dense matrix arrays
const adata = denseMatrix._data
const asize = denseMatrix._size
const adt = denseMatrix._datatype
const adt = denseMatrix._datatype || denseMatrix.getDataType()
// sparse matrix arrays
const bvalues = sparseMatrix._values
const bindex = sparseMatrix._index
const bptr = sparseMatrix._ptr
const bsize = sparseMatrix._size
const bdt = sparseMatrix._datatype
const bdt = sparseMatrix._datatype || sparseMatrix._data === undefined ? sparseMatrix._datatype : sparseMatrix.getDataType()

// validate dimensions
if (asize.length !== bsize.length) { throw new DimensionError(asize.length, bsize.length) }
Expand All @@ -57,7 +57,7 @@ export const createMatAlgo03xDSf = /* #__PURE__ */ factory(name, dependencies, (
let cf = callback

// process data types
if (typeof adt === 'string' && adt === bdt) {
if (typeof adt === 'string' && adt === bdt && adt !== 'mixed') {
// datatype
dt = adt
// convert 0 to the same datatype
Expand Down Expand Up @@ -109,7 +109,7 @@ export const createMatAlgo03xDSf = /* #__PURE__ */ factory(name, dependencies, (
return denseMatrix.createDenseMatrix({
data: cdata,
size: [rows, columns],
datatype: dt
datatype: adt === denseMatrix._datatype && bdt === sparseMatrix._datatype ? dt : undefined
})
}
})
Loading

0 comments on commit 207623c

Please sign in to comment.