From 74ea43db8cdc4e814b1f9d1208b39613223460c0 Mon Sep 17 00:00:00 2001 From: Jai Radhakrishnan <55522316+jairad26@users.noreply.github.com> Date: Wed, 27 Nov 2024 15:13:20 -0800 Subject: [PATCH] vector package should return generic type in computations --- sdk/assemblyscript/src/assembly/vectors.ts | 18 ++++++++-------- sdk/go/pkg/vectors/vectors.go | 25 +++++++--------------- sdk/go/pkg/vectors/vectors_test.go | 2 +- 3 files changed, 18 insertions(+), 27 deletions(-) diff --git a/sdk/assemblyscript/src/assembly/vectors.ts b/sdk/assemblyscript/src/assembly/vectors.ts index 0d48cd10..209cc73b 100644 --- a/sdk/assemblyscript/src/assembly/vectors.ts +++ b/sdk/assemblyscript/src/assembly/vectors.ts @@ -237,9 +237,9 @@ export function dot(a: T[], b: T[]): T { * @param a: The vector * @returns: The magnitude of the vector */ -export function magnitude(a: T[]): f64 { +export function magnitude(a: T[]): T { checkValidArray(a); - return sqrt(dot(a, a)); + return sqrt(dot(a, a)); } /** @@ -248,12 +248,12 @@ export function magnitude(a: T[]): f64 { * @param b: The second vector * @returns: The cross product of the two vectors */ -export function normalize(a: T[]): f64[] { +export function normalize(a: T[]): T[] { checkValidArray(a); const magnitudeValue = magnitude(a); - const result: f64[] = new Array(a.length); + const result: T[] = new Array(a.length); for (let i = 0; i < a.length; i++) { - result[i] = (a[i] as f64) / magnitudeValue; + result[i] = ((a[i] as T) / magnitudeValue) as T; } return result; } @@ -294,9 +294,9 @@ export function product(a: T[]): T { * @param a: The vector * @returns: The mean of the vector */ -export function mean(a: T[]): f64 { +export function mean(a: T[]): T { checkValidArray(a); - return f64(sum(a)) / f64(a.length); + return (sum(a) / a.length) as T; } /** @@ -371,7 +371,7 @@ export function euclidianDistance(a: T[], b: T[]): f64 { checkValidArray(a); let sum: number = 0; for (let i = 0; i < a.length; i++) { - sum += f64(a[i] - b[i]) ** 2; + sum += ((a[i] - b[i]) ** 2) as T; } - return sqrt(sum); + return sqrt(sum as T); } diff --git a/sdk/go/pkg/vectors/vectors.go b/sdk/go/pkg/vectors/vectors.go index 41850e51..4cb07853 100644 --- a/sdk/go/pkg/vectors/vectors.go +++ b/sdk/go/pkg/vectors/vectors.go @@ -128,14 +128,14 @@ func Dot[T constraints.Integer | constraints.Float](a, b []T) T { } // Magnitude computes the magnitude of a vector. -func Magnitude[T constraints.Integer | constraints.Float](a []T) float64 { - return math.Sqrt(float64(Dot(a, a))) +func Magnitude[T constraints.Integer | constraints.Float](a []T) T { + return T(math.Sqrt(float64(Dot(a, a)))) } // Normalize normalizes a vector to have a magnitude of 1. -func Normalize[T constraints.Integer | constraints.Float](a []T) []float64 { +func Normalize[T constraints.Integer | constraints.Float](a []T) []T { mag := Magnitude(a) - return DivideNumber(convertToFloat64Slice(a), mag) + return DivideNumber(a, mag) } // Sum computes the sum of all elements in a vector. @@ -157,9 +157,9 @@ func Product[T constraints.Integer | constraints.Float](a []T) T { } // func Mean computes the mean of a vector. -func Mean[T constraints.Integer | constraints.Float](a []T) float64 { +func Mean[T constraints.Integer | constraints.Float](a []T) T { assertNonEmpty(a) - return float64(Sum(a)) / float64(len(a)) + return T(float64(Sum(a)) / float64(len(a))) } // Min computes the minimum element in a vector. @@ -208,13 +208,13 @@ func AbsInPlace[T constraints.Integer | constraints.Float](a []T) { } // EuclidianDistance computes the Euclidian distance between two vectors. -func EuclidianDistance[T constraints.Integer | constraints.Float](a, b []T) float64 { +func EuclidianDistance[T constraints.Integer | constraints.Float](a, b []T) T { assertEqualLength(a, b) var result float64 = 0 for i := 0; i < len(a); i++ { result += math.Pow(float64(a[i]-b[i]), 2) } - return math.Sqrt(result) + return T(math.Sqrt(result)) } func assertEqualLength[T constraints.Integer | constraints.Float](a, b []T) { @@ -223,15 +223,6 @@ func assertEqualLength[T constraints.Integer | constraints.Float](a, b []T) { } } -// convertToFloat64Slice converts a slice of type []T to type []float64. -func convertToFloat64Slice[T constraints.Integer | constraints.Float](a []T) []float64 { - result := make([]float64, len(a)) - for i := range a { - result[i] = float64(a[i]) - } - return result -} - func assertNonEmpty[T constraints.Integer | constraints.Float](a []T) { if len(a) == 0 { panic("vector must be non-empty") diff --git a/sdk/go/pkg/vectors/vectors_test.go b/sdk/go/pkg/vectors/vectors_test.go index b8993be2..d15fb3a9 100644 --- a/sdk/go/pkg/vectors/vectors_test.go +++ b/sdk/go/pkg/vectors/vectors_test.go @@ -243,7 +243,7 @@ func TestProduct(t *testing.T) { func TestMean(t *testing.T) { a := []uint8{1, 2, 3} - expected := 2.0 // (1 + 2 + 3) / 3 + expected := uint8(2) // (1 + 2 + 3) / 3 result := Mean(a)