Skip to content

Commit

Permalink
speed up DOT
Browse files Browse the repository at this point in the history
  • Loading branch information
stylewarning authored and kartik-s committed Jun 7, 2022
1 parent e0a872e commit d49024b
Show file tree
Hide file tree
Showing 6 changed files with 47 additions and 17 deletions.
1 change: 1 addition & 0 deletions magicl.asd
Original file line number Diff line number Diff line change
Expand Up @@ -57,6 +57,7 @@
(:file "types/complex-single-float")
(:file "types/complex-double-float")
(:file "types/int32")
(:file "types/specialized-vector")
(:file "constructors")
(:file "specialize-constructor")
(:file "polynomial-solver")))
Expand Down
7 changes: 0 additions & 7 deletions src/high-level/types/complex-double-float.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,6 @@
matrix/complex-double-float
vector/complex-double-float)

(defmethod dot ((vector1 vector/complex-double-float) (vector2 vector/complex-double-float))
(assert (cl:= (size vector1) (size vector2))
() "Vectors must have the same size. The first vector is size ~a and the second vector is size ~a."
(size vector1) (size vector2))
(loop :for i :below (size vector1)
:sum (* (tref vector1 i) (conjugate (tref vector2 i)))))

(defmethod =-lisp ((tensor1 tensor/complex-double-float) (tensor2 tensor/complex-double-float) &optional (epsilon *double-comparison-threshold*))
(unless (equal (shape tensor1) (shape tensor2))
(return-from =-lisp nil))
Expand Down
7 changes: 0 additions & 7 deletions src/high-level/types/complex-single-float.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -22,13 +22,6 @@
matrix/complex-single-float
vector/complex-single-float)

(defmethod dot ((vector1 vector/complex-single-float) (vector2 vector/complex-single-float))
(assert (cl:= (size vector1) (size vector2))
() "Vectors must have the same size. The first vector is size ~a and the second vector is size ~a."
(size vector1) (size vector2))
(loop :for i :below (size vector1)
:sum (* (tref vector1 i) (conjugate (tref vector2 i)))))

(defmethod =-lisp ((tensor1 tensor/complex-single-float) (tensor2 tensor/complex-single-float) &optional (epsilon *float-comparison-threshold*))
(unless (equal (shape tensor1) (shape tensor2))
(return-from =-lisp nil))
Expand Down
1 change: 1 addition & 0 deletions src/high-level/types/single-float.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -57,3 +57,4 @@
epsilon)
(return-from =-lisp nil))))
t)

42 changes: 42 additions & 0 deletions src/high-level/types/specialized-vector.lisp
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
;;;; specialized-vector.lisp
;;;;
;;;; Author: Robert Smith

(in-package #:magicl)

;;;; Here we have some functions that specialize on many types

(defmacro define-common-real-vector-methods (vector-class elt-type)
`(progn
(defmethod dot-lisp ((vec1 ,vector-class) (vec2 ,vector-class))
(let ((size1 (size vec1))
(size2 (size vec2)))
(assert (cl:= size1 size2))
(let ((s1 (storage vec1))
(s2 (storage vec2)))
(declare (optimize speed)
(type (simple-array ,elt-type (*)) s1 s2))
(loop :with s :of-type ,elt-type := ,(coerce 0 elt-type)
:for i :of-type alexandria:array-index :below size1
:do (incf s (* (aref s1 i) (aref s2 i)))
:finally (return s)))))))

(defmacro define-common-complex-vector-methods (vector-class elt-type)
`(progn
(defmethod dot-lisp ((vec1 ,vector-class) (vec2 ,vector-class))
(let ((size1 (size vec1))
(size2 (size vec2)))
(assert (cl:= size1 size2))
(let ((s1 (storage vec1))
(s2 (storage vec2)))
(declare (optimize speed)
(type (simple-array ,elt-type (*)) s1 s2))
(loop :with s :of-type ,elt-type := ,(coerce 0 elt-type)
:for i :of-type alexandria:array-index :below size1
:do (incf s (* (aref s1 i) (conjugate (aref s2 i))))
:finally (return s)))))))

(define-common-real-vector-methods vector/single-float single-float)
(define-common-real-vector-methods vector/double-float double-float)
(define-common-complex-vector-methods vector/complex-single-float (complex single-float))
(define-common-complex-vector-methods vector/complex-double-float (complex double-float))
6 changes: 3 additions & 3 deletions src/high-level/vector.lisp
Original file line number Diff line number Diff line change
Expand Up @@ -153,13 +153,13 @@ ELEMENT-TYPE, CAST, COPY-TENSOR, DEEP-COPY-TENSOR, TREF, SETF TREF)"
(assertion (cl:= 1 (length new-value))))
(setf (vector-size vector) (first new-value))))

(defgeneric dot (vector1 vector2)
(:documentation "Compute the dot product of two vectors")
(define-extensible-function (dot dot-lisp) (vector1 vector2)
(:documentation "Compute the dot product of two vectors. For complex vectors, this conjugates the second argument.")
(:method ((vector1 vector) (vector2 vector))
(policy-cond:with-expectations (> speed safety)
((assertion (cl:= (size vector1) (size vector2))))
(loop :for i :below (size vector1)
:sum (* (tref vector1 i) (tref vector2 i))))))
:sum (* (tref vector1 i) (conjugate (tref vector2 i)))))))

(deftype p-norm-type ()
`(or (member :inf :infinity :positive-infinity)
Expand Down

0 comments on commit d49024b

Please sign in to comment.