Skip to content

Commit

Permalink
vector package should return generic type in computations
Browse files Browse the repository at this point in the history
  • Loading branch information
jairad26 committed Nov 27, 2024
1 parent e5abb35 commit 74ea43d
Show file tree
Hide file tree
Showing 3 changed files with 18 additions and 27 deletions.
18 changes: 9 additions & 9 deletions sdk/assemblyscript/src/assembly/vectors.ts
Original file line number Diff line number Diff line change
Expand Up @@ -237,9 +237,9 @@ export function dot<T extends number>(a: T[], b: T[]): T {
* @param a: The vector
* @returns: The magnitude of the vector
*/
export function magnitude<T extends number>(a: T[]): f64 {
export function magnitude<T extends number>(a: T[]): T {
checkValidArray(a);
return sqrt<f64>(dot(a, a));
return sqrt<T>(dot(a, a));
}

/**
Expand All @@ -248,12 +248,12 @@ export function magnitude<T extends number>(a: T[]): f64 {
* @param b: The second vector
* @returns: The cross product of the two vectors
*/
export function normalize<T extends number>(a: T[]): f64[] {
export function normalize<T extends number>(a: T[]): T[] {
checkValidArray(a);
const magnitudeValue = magnitude(a);
const result: f64[] = new Array<f64>(a.length);
const result: T[] = new Array<T>(a.length);
for (let i = 0; i < a.length; i++) {
result[i] = (a[i] as f64) / magnitudeValue;
result[i] = ((a[i] as T) / magnitudeValue) as T;
}
return result;
}
Expand Down Expand Up @@ -294,9 +294,9 @@ export function product<T extends number>(a: T[]): T {
* @param a: The vector
* @returns: The mean of the vector
*/
export function mean<T extends number>(a: T[]): f64 {
export function mean<T extends number>(a: T[]): T {
checkValidArray(a);
return f64(sum(a)) / f64(a.length);
return (sum(a) / a.length) as T;
}

/**
Expand Down Expand Up @@ -371,7 +371,7 @@ export function euclidianDistance<T extends number>(a: T[], b: T[]): f64 {
checkValidArray(a);
let sum: number = 0;
for (let i = 0; i < a.length; i++) {
sum += f64(a[i] - b[i]) ** 2;
sum += ((a[i] - b[i]) ** 2) as T;
}
return sqrt<f64>(sum);
return sqrt<T>(sum as T);
}
25 changes: 8 additions & 17 deletions sdk/go/pkg/vectors/vectors.go
Original file line number Diff line number Diff line change
Expand Up @@ -128,14 +128,14 @@ func Dot[T constraints.Integer | constraints.Float](a, b []T) T {
}

// Magnitude computes the magnitude of a vector.
func Magnitude[T constraints.Integer | constraints.Float](a []T) float64 {
return math.Sqrt(float64(Dot(a, a)))
func Magnitude[T constraints.Integer | constraints.Float](a []T) T {
return T(math.Sqrt(float64(Dot(a, a))))
}

// Normalize normalizes a vector to have a magnitude of 1.
func Normalize[T constraints.Integer | constraints.Float](a []T) []float64 {
func Normalize[T constraints.Integer | constraints.Float](a []T) []T {
mag := Magnitude(a)
return DivideNumber(convertToFloat64Slice(a), mag)
return DivideNumber(a, mag)
}

// Sum computes the sum of all elements in a vector.
Expand All @@ -157,9 +157,9 @@ func Product[T constraints.Integer | constraints.Float](a []T) T {
}

// func Mean computes the mean of a vector.
func Mean[T constraints.Integer | constraints.Float](a []T) float64 {
func Mean[T constraints.Integer | constraints.Float](a []T) T {
assertNonEmpty(a)
return float64(Sum(a)) / float64(len(a))
return T(float64(Sum(a)) / float64(len(a)))
}

// Min computes the minimum element in a vector.
Expand Down Expand Up @@ -208,13 +208,13 @@ func AbsInPlace[T constraints.Integer | constraints.Float](a []T) {
}

// EuclidianDistance computes the Euclidian distance between two vectors.
func EuclidianDistance[T constraints.Integer | constraints.Float](a, b []T) float64 {
func EuclidianDistance[T constraints.Integer | constraints.Float](a, b []T) T {
assertEqualLength(a, b)
var result float64 = 0
for i := 0; i < len(a); i++ {
result += math.Pow(float64(a[i]-b[i]), 2)
}
return math.Sqrt(result)
return T(math.Sqrt(result))
}

func assertEqualLength[T constraints.Integer | constraints.Float](a, b []T) {
Expand All @@ -223,15 +223,6 @@ func assertEqualLength[T constraints.Integer | constraints.Float](a, b []T) {
}
}

// convertToFloat64Slice converts a slice of type []T to type []float64.
func convertToFloat64Slice[T constraints.Integer | constraints.Float](a []T) []float64 {
result := make([]float64, len(a))
for i := range a {
result[i] = float64(a[i])
}
return result
}

func assertNonEmpty[T constraints.Integer | constraints.Float](a []T) {
if len(a) == 0 {
panic("vector must be non-empty")
Expand Down
2 changes: 1 addition & 1 deletion sdk/go/pkg/vectors/vectors_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -243,7 +243,7 @@ func TestProduct(t *testing.T) {

func TestMean(t *testing.T) {
a := []uint8{1, 2, 3}
expected := 2.0 // (1 + 2 + 3) / 3
expected := uint8(2) // (1 + 2 + 3) / 3

result := Mean(a)

Expand Down

0 comments on commit 74ea43d

Please sign in to comment.