From 58a3ff4ad61dd5115cfb9c931d5f81bece4d6e05 Mon Sep 17 00:00:00 2001 From: Daylon Wilkins Date: Tue, 18 Feb 2025 06:02:17 -0800 Subject: [PATCH] Changed created functions to persist on the root --- core/context.go | 38 ++++++- core/functions/function.go | 128 ++++++++++++++++++++++++ core/functions/merge.go | 39 ++++++++ core/functions/serialization.go | 120 ++++++++++++++++++++++ core/init.go | 2 + core/rootvalue.go | 132 ++++++++++++++++++++----- core/storage.go | 20 ++++ flatbuffers/gen/serial/rootvalue.go | 48 ++++++++- flatbuffers/serial/rootvalue.fbs | 4 + server/ast/create_function.go | 1 + server/functions/framework/catalog.go | 7 +- server/functions/framework/provider.go | 64 ++++++++++-- server/node/create_function.go | 34 +++++-- server/plpgsql/interpreter_logic.go | 8 +- testing/go/create_function_test.go | 53 ++++++++++ utils/reader.go | 20 ++++ utils/writer.go | 16 +++ 17 files changed, 683 insertions(+), 51 deletions(-) create mode 100644 core/functions/function.go create mode 100644 core/functions/merge.go create mode 100644 core/functions/serialization.go diff --git a/core/context.go b/core/context.go index 0f536907d3..e2c3a0a58a 100644 --- a/core/context.go +++ b/core/context.go @@ -16,12 +16,12 @@ package core import ( "github.com/cockroachdb/errors" - "github.com/dolthub/dolt/go/libraries/doltcore/doltdb" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess" "github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve" "github.com/dolthub/go-mysql-server/sql" + "github.com/dolthub/doltgresql/core/functions" "github.com/dolthub/doltgresql/core/sequences" "github.com/dolthub/doltgresql/core/typecollection" ) @@ -30,6 +30,7 @@ import ( type contextValues struct { collection *sequences.Collection types *typecollection.TypeCollection + funcs *functions.Collection pgCatalogCache any } @@ -63,6 +64,17 @@ func getRootFromContext(ctx *sql.Context) (*dsess.DoltSession, *RootValue, error return session, state.WorkingRoot().(*RootValue), nil } +// IsContextValid returns whether the context is valid for use with any of the functions in the package. If this is not +// false, then there's a high likelihood that the context is being used in a temporary scenario (such as setting up the +// database, etc.). +func IsContextValid(ctx *sql.Context) bool { + if ctx == nil { + return false + } + _, ok := ctx.Session.(*dsess.DoltSession) + return ok +} + // GetPgCatalogCache returns a cache of data for pg_catalog tables. This function should only be used by // pg_catalog table handlers. The catalog cache instance stores generated pg_catalog table data so that // it only has to generated table data once per query. @@ -185,6 +197,26 @@ func GetSqlTableFromContext(ctx *sql.Context, databaseName string, tableName dol return nil, nil } +// GetFunctionsCollectionFromContext returns the functions collection from the given context. Will always return a +// collection if no error is returned. +func GetFunctionsCollectionFromContext(ctx *sql.Context) (*functions.Collection, error) { + cv, err := getContextValues(ctx) + if err != nil { + return nil, err + } + if cv.funcs == nil { + _, root, err := getRootFromContext(ctx) + if err != nil { + return nil, err + } + cv.funcs, err = root.GetFunctions(ctx) + if err != nil { + return nil, err + } + } + return cv.funcs, nil +} + // GetSequencesCollectionFromContext returns the given sequence collection from the context. Will always return a collection if // no error is returned. func GetSequencesCollectionFromContext(ctx *sql.Context) (*sequences.Collection, error) { @@ -247,6 +279,10 @@ func CloseContextRootFinalizer(ctx *sql.Context) error { if err != nil { return err } + newRoot, err = newRoot.PutFunctions(ctx, cv.funcs) + if err != nil { + return err + } if newRoot != nil { if err = session.SetWorkingRoot(ctx, ctx.GetCurrentDatabase(), newRoot); err != nil { // TODO: We need a way to see if the session has a writeable working root diff --git a/core/functions/function.go b/core/functions/function.go new file mode 100644 index 0000000000..bb50299fcd --- /dev/null +++ b/core/functions/function.go @@ -0,0 +1,128 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package functions + +import ( + "maps" + "sync" + + "github.com/cockroachdb/errors" + + "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/server/plpgsql" +) + +// Collection contains a collection of functions. +type Collection struct { + funcMap map[id.Function]*Function + overloadMap map[id.Function][]*Function + mutex *sync.Mutex +} + +// Function represents a created function. +type Function struct { + ID id.Function + ReturnType id.Type + ParameterNames []string + ParameterTypes []id.Type + Variadic bool + IsNonDeterministic bool + Strict bool + Operations []plpgsql.InterpreterOperation +} + +// GetFunction returns the function with the given ID. Returns nil if the function cannot be found. +func (pgf *Collection) GetFunction(funcID id.Function) *Function { + pgf.mutex.Lock() + defer pgf.mutex.Unlock() + + if f, ok := pgf.funcMap[funcID]; ok { + return f + } + return nil +} + +// GetFunctionOverloads returns the overloads for the function matching the schema and the function name. The parameter +// types are ignored when searching for overloads. +func (pgf *Collection) GetFunctionOverloads(funcID id.Function) []*Function { + pgf.mutex.Lock() + defer pgf.mutex.Unlock() + + funcNameOnly := id.NewFunction(funcID.SchemaName(), funcID.FunctionName()) + return pgf.overloadMap[funcNameOnly] +} + +// HasFunction returns whether the function is present. +func (pgf *Collection) HasFunction(funcID id.Function) bool { + return pgf.GetFunction(funcID) != nil +} + +// AddFunction adds a new function. +func (pgf *Collection) AddFunction(f *Function) error { + pgf.mutex.Lock() + defer pgf.mutex.Unlock() + + if _, ok := pgf.funcMap[f.ID]; ok { + return errors.Errorf(`function "%s" already exists with same argument types`, f.ID.FunctionName()) + } + pgf.funcMap[f.ID] = f + funcNameOnly := id.NewFunction(f.ID.SchemaName(), f.ID.FunctionName()) + pgf.overloadMap[funcNameOnly] = append(pgf.overloadMap[funcNameOnly], f) + return nil +} + +// DropFunction drops an existing function. +func (pgf *Collection) DropFunction(funcID id.Function) error { + pgf.mutex.Lock() + defer pgf.mutex.Unlock() + + if _, ok := pgf.funcMap[funcID]; ok { + delete(pgf.funcMap, funcID) + funcNameOnly := id.NewFunction(funcID.SchemaName(), funcID.FunctionName()) + for i, f := range pgf.overloadMap[funcNameOnly] { + if f.ID == funcID { + pgf.overloadMap[funcNameOnly] = append(pgf.overloadMap[funcNameOnly][:i], pgf.overloadMap[funcNameOnly][i+1:]...) + break + } + } + return nil + } + return errors.Errorf(`function %s does not exist`, funcID.FunctionName()) +} + +// IterateFunctions iterates over all functions in the collection. +func (pgf *Collection) IterateFunctions(callback func(f *Function) error) error { + pgf.mutex.Lock() + defer pgf.mutex.Unlock() + + for _, f := range pgf.funcMap { + if err := callback(f); err != nil { + return err + } + } + return nil +} + +// Clone returns a new *Collection with the same contents as the original. +func (pgf *Collection) Clone() *Collection { + pgf.mutex.Lock() + defer pgf.mutex.Unlock() + + return &Collection{ + funcMap: maps.Clone(pgf.funcMap), + overloadMap: maps.Clone(pgf.overloadMap), + mutex: &sync.Mutex{}, + } +} diff --git a/core/functions/merge.go b/core/functions/merge.go new file mode 100644 index 0000000000..6c83bab6ff --- /dev/null +++ b/core/functions/merge.go @@ -0,0 +1,39 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package functions + +import ( + "context" + + "github.com/cockroachdb/errors" +) + +// Merge handles merging functions on our root and their root. +func Merge(ctx context.Context, ourCollection, theirCollection, ancCollection *Collection) (*Collection, error) { + mergedCollection := ourCollection.Clone() + err := theirCollection.IterateFunctions(func(theirFunc *Function) error { + // If we don't have the sequence, then we simply add it + if !mergedCollection.HasFunction(theirFunc.ID) { + newFunc := *theirFunc + return mergedCollection.AddFunction(&newFunc) + } + // TODO: figure out a decent merge strategy + return errors.Errorf(`unable to merge "%s"`, theirFunc.ID.AsId().String()) + }) + if err != nil { + return nil, err + } + return mergedCollection, nil +} diff --git a/core/functions/serialization.go b/core/functions/serialization.go new file mode 100644 index 0000000000..1bb3b39948 --- /dev/null +++ b/core/functions/serialization.go @@ -0,0 +1,120 @@ +// Copyright 2024 Dolthub, Inc. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package functions + +import ( + "context" + "sync" + + "github.com/cockroachdb/errors" + + "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/server/plpgsql" + "github.com/dolthub/doltgresql/utils" +) + +// Serialize returns the Collection as a byte slice. If the Collection is nil, then this returns a nil slice. +func (pgf *Collection) Serialize(ctx context.Context) ([]byte, error) { + if pgf == nil { + return nil, nil + } + pgf.mutex.Lock() + defer pgf.mutex.Unlock() + + // Write all of the functions to the writer + writer := utils.NewWriter(256) + writer.VariableUint(0) // Version + funcIDs := utils.GetMapKeysSorted(pgf.funcMap) + writer.VariableUint(uint64(len(funcIDs))) + for _, funcID := range funcIDs { + f := pgf.funcMap[funcID] + writer.Id(f.ID.AsId()) + writer.Id(f.ReturnType.AsId()) + writer.StringSlice(f.ParameterNames) + writer.IdTypeSlice(f.ParameterTypes) + writer.Bool(f.Variadic) + writer.Bool(f.IsNonDeterministic) + writer.Bool(f.Strict) + // Write the operations + writer.VariableUint(uint64(len(f.Operations))) + for _, op := range f.Operations { + writer.Uint16(uint16(op.OpCode)) + writer.String(op.PrimaryData) + writer.StringSlice(op.SecondaryData) + writer.String(op.Target) + writer.Int32(int32(op.Index)) + } + } + + return writer.Data(), nil +} + +// Deserialize returns the Collection that was serialized in the byte slice. Returns an empty Collection if data is nil +// or empty. +func Deserialize(ctx context.Context, data []byte) (*Collection, error) { + if len(data) == 0 { + return &Collection{ + funcMap: make(map[id.Function]*Function), + overloadMap: make(map[id.Function][]*Function), + mutex: &sync.Mutex{}, + }, nil + } + funcMap := make(map[id.Function]*Function) + overloadMap := make(map[id.Function][]*Function) + reader := utils.NewReader(data) + version := reader.VariableUint() + if version != 0 { + return nil, errors.Errorf("version %d of functions is not supported, please upgrade the server", version) + } + + // Read from the reader + numOfFunctions := reader.VariableUint() + for i := uint64(0); i < numOfFunctions; i++ { + f := &Function{} + f.ID = id.Function(reader.Id()) + f.ReturnType = id.Type(reader.Id()) + f.ParameterNames = reader.StringSlice() + f.ParameterTypes = reader.IdTypeSlice() + f.Variadic = reader.Bool() + f.IsNonDeterministic = reader.Bool() + f.Strict = reader.Bool() + // Read the operations + opCount := reader.VariableUint() + f.Operations = make([]plpgsql.InterpreterOperation, opCount) + for opIdx := uint64(0); opIdx < opCount; opIdx++ { + op := plpgsql.InterpreterOperation{} + op.OpCode = plpgsql.OpCode(reader.Uint16()) + op.PrimaryData = reader.String() + op.SecondaryData = reader.StringSlice() + op.Target = reader.String() + op.Index = int(reader.Int32()) + f.Operations[opIdx] = op + } + // Add the function to each map + funcMap[f.ID] = f + funcNameOnly := id.NewFunction(f.ID.SchemaName(), f.ID.FunctionName()) + overloadMap[funcNameOnly] = append(overloadMap[funcNameOnly], f) + } + if !reader.IsEmpty() { + return nil, errors.Errorf("extra data found while deserializing functions") + } + + // Return the deserialized object + return &Collection{ + funcMap: funcMap, + overloadMap: overloadMap, + mutex: &sync.Mutex{}, + }, nil +} diff --git a/core/init.go b/core/init.go index 8921a3335d..8b49705c80 100644 --- a/core/init.go +++ b/core/init.go @@ -19,6 +19,7 @@ import ( "github.com/dolthub/dolt/go/store/types" "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/server/plpgsql" ) // Init initializes this package. @@ -27,5 +28,6 @@ func Init() { doltdb.NewRootValue = newRootValue types.DoltgresRootValueHumanReadableStringAtIndentationLevel = rootValueHumanReadableStringAtIndentationLevel types.DoltgresRootValueWalkAddrs = rootValueWalkAddrs + plpgsql.GetTypesCollectionFromContext = GetTypesCollectionFromContext id.RegisterListener(sequenceIDListener{}, id.Section_Table) } diff --git a/core/rootvalue.go b/core/rootvalue.go index e160f9a100..0f8e82c9f7 100644 --- a/core/rootvalue.go +++ b/core/rootvalue.go @@ -30,6 +30,7 @@ import ( "github.com/dolthub/dolt/go/store/prolly/tree" "github.com/dolthub/dolt/go/store/types" + "github.com/dolthub/doltgresql/core/functions" "github.com/dolthub/doltgresql/core/id" "github.com/dolthub/doltgresql/core/sequences" "github.com/dolthub/doltgresql/core/typecollection" @@ -206,11 +207,11 @@ func (root *RootValue) GetForeignKeyCollection(ctx context.Context) (*doltdb.For return root.fkc.Copy(), nil } -// GetSequences returns all sequences that are on the root. -func (root *RootValue) GetSequences(ctx context.Context) (*sequences.Collection, error) { - h := root.st.GetSequences() +// GetFunctions returns all functions that are on the root. +func (root *RootValue) GetFunctions(ctx context.Context) (*functions.Collection, error) { + h := root.st.GetFunctions() if h.IsEmpty() { - return sequences.Deserialize(ctx, nil) + return functions.Deserialize(ctx, nil) } dataValue, err := root.vrw.ReadValue(ctx, h) if err != nil { @@ -224,16 +225,16 @@ func (root *RootValue) GetSequences(ctx context.Context) (*sequences.Collection, return nil, err } if uint64(n) != dataBlobLength { - return nil, errors.Errorf("wanted %d bytes from blob for sequences, got %d", dataBlobLength, n) + return nil, errors.Errorf("wanted %d bytes from blob for functions, got %d", dataBlobLength, n) } - return sequences.Deserialize(ctx, data) + return functions.Deserialize(ctx, data) } -// GetTypes returns all types that are on the root. -func (root *RootValue) GetTypes(ctx context.Context) (*typecollection.TypeCollection, error) { - h := root.st.GetTypes() +// GetSequences returns all sequences that are on the root. +func (root *RootValue) GetSequences(ctx context.Context) (*sequences.Collection, error) { + h := root.st.GetSequences() if h.IsEmpty() { - return typecollection.Deserialize(ctx, nil) + return sequences.Deserialize(ctx, nil) } dataValue, err := root.vrw.ReadValue(ctx, h) if err != nil { @@ -247,9 +248,9 @@ func (root *RootValue) GetTypes(ctx context.Context) (*typecollection.TypeCollec return nil, err } if uint64(n) != dataBlobLength { - return nil, errors.Errorf("wanted %d bytes from blob for types, got %d", dataBlobLength, n) + return nil, errors.Errorf("wanted %d bytes from blob for sequences, got %d", dataBlobLength, n) } - return typecollection.Deserialize(ctx, data) + return sequences.Deserialize(ctx, data) } // GetTable implements the interface doltdb.RootValue. @@ -301,6 +302,29 @@ func (root *RootValue) GetTableNames(ctx context.Context, schemaName string) ([] return names, nil } +// GetTypes returns all types that are on the root. +func (root *RootValue) GetTypes(ctx context.Context) (*typecollection.TypeCollection, error) { + h := root.st.GetTypes() + if h.IsEmpty() { + return typecollection.Deserialize(ctx, nil) + } + dataValue, err := root.vrw.ReadValue(ctx, h) + if err != nil { + return nil, err + } + dataBlob := dataValue.(types.Blob) + dataBlobLength := dataBlob.Len() + data := make([]byte, dataBlobLength) + n, err := dataBlob.ReadAt(context.Background(), data, 0) + if err != nil && err != io.EOF { + return nil, err + } + if uint64(n) != dataBlobLength { + return nil, errors.Errorf("wanted %d bytes from blob for types, got %d", dataBlobLength, n) + } + return typecollection.Deserialize(ctx, data) +} + // HandlePostMerge implements the interface doltdb.RootValue. func (root *RootValue) HandlePostMerge(ctx context.Context, ourRoot, theirRoot, ancRoot doltdb.RootValue) (doltdb.RootValue, error) { // Handle sequences @@ -309,7 +333,33 @@ func (root *RootValue) HandlePostMerge(ctx context.Context, ourRoot, theirRoot, return nil, err } // Handle types - return root.handlePostTypesMerge(ctx, ourRoot, theirRoot, ancRoot) + _, err = root.handlePostTypesMerge(ctx, ourRoot, theirRoot, ancRoot) + if err != nil { + return nil, err + } + // Handle functions + return root.handlePostFunctionsMerge(ctx, ourRoot, theirRoot, ancRoot) +} + +// handlePostFunctionsMerge merges functions. +func (root *RootValue) handlePostFunctionsMerge(ctx context.Context, ourRoot, theirRoot, ancRoot doltdb.RootValue) (doltdb.RootValue, error) { + ourFunctions, err := ourRoot.(*RootValue).GetFunctions(ctx) + if err != nil { + return nil, err + } + theirFunctions, err := theirRoot.(*RootValue).GetFunctions(ctx) + if err != nil { + return nil, err + } + ancFunctions, err := ancRoot.(*RootValue).GetFunctions(ctx) + if err != nil { + return nil, err + } + mergedFunctions, err := functions.Merge(ctx, ourFunctions, theirFunctions, ancFunctions) + if err != nil { + return nil, err + } + return root.PutFunctions(ctx, mergedFunctions) } // handlePostSequencesMerge merges sequences. @@ -441,34 +491,37 @@ func (root *RootValue) NomsValue() types.Value { return root.st.nomsValue() } -// PutTypes writes the given types to the returned root value. -func (root *RootValue) PutTypes(ctx context.Context, typ *typecollection.TypeCollection) (*RootValue, error) { - data, err := typ.Serialize(ctx) +// PutForeignKeyCollection implements the interface doltdb.RootValue. +func (root *RootValue) PutForeignKeyCollection(ctx context.Context, fkc *doltdb.ForeignKeyCollection) (doltdb.RootValue, error) { + value, err := doltdb.SerializeForeignKeys(ctx, root.vrw, fkc) if err != nil { return nil, err } - dataBlob, err := types.NewBlob(ctx, root.vrw, bytes.NewReader(data)) + newStorage, err := root.st.SetForeignKeyMap(ctx, root.vrw, value) if err != nil { return nil, err } - ref, err := root.vrw.WriteValue(ctx, dataBlob) + return root.withStorage(newStorage), nil +} + +// PutFunctions writes the given functions to the returned root value. +func (root *RootValue) PutFunctions(ctx context.Context, funcCollection *functions.Collection) (*RootValue, error) { + if funcCollection == nil { + return root, nil + } + data, err := funcCollection.Serialize(ctx) if err != nil { return nil, err } - newStorage, err := root.st.SetTypes(ctx, ref.TargetHash()) + dataBlob, err := types.NewBlob(ctx, root.vrw, bytes.NewReader(data)) if err != nil { return nil, err } - return root.withStorage(newStorage), nil -} - -// PutForeignKeyCollection implements the interface doltdb.RootValue. -func (root *RootValue) PutForeignKeyCollection(ctx context.Context, fkc *doltdb.ForeignKeyCollection) (doltdb.RootValue, error) { - value, err := doltdb.SerializeForeignKeys(ctx, root.vrw, fkc) + ref, err := root.vrw.WriteValue(ctx, dataBlob) if err != nil { return nil, err } - newStorage, err := root.st.SetForeignKeyMap(ctx, root.vrw, value) + newStorage, err := root.st.SetFunctions(ctx, ref.TargetHash()) if err != nil { return nil, err } @@ -477,6 +530,9 @@ func (root *RootValue) PutForeignKeyCollection(ctx context.Context, fkc *doltdb. // PutSequences writes the given sequences to the returned root value. func (root *RootValue) PutSequences(ctx context.Context, seq *sequences.Collection) (*RootValue, error) { + if seq == nil { + return root, nil + } data, err := seq.Serialize(ctx) if err != nil { return nil, err @@ -512,6 +568,30 @@ func (root *RootValue) PutTable(ctx context.Context, tName doltdb.TableName, tab return root.putTable(ctx, tName, tableRef) } +// PutTypes writes the given types to the returned root value. +func (root *RootValue) PutTypes(ctx context.Context, typ *typecollection.TypeCollection) (*RootValue, error) { + if typ == nil { + return root, nil + } + data, err := typ.Serialize(ctx) + if err != nil { + return nil, err + } + dataBlob, err := types.NewBlob(ctx, root.vrw, bytes.NewReader(data)) + if err != nil { + return nil, err + } + ref, err := root.vrw.WriteValue(ctx, dataBlob) + if err != nil { + return nil, err + } + newStorage, err := root.st.SetTypes(ctx, ref.TargetHash()) + if err != nil { + return nil, err + } + return root.withStorage(newStorage), nil +} + // RemoveTables implements the interface doltdb.RootValue. func (root *RootValue) RemoveTables( ctx context.Context, diff --git a/core/storage.go b/core/storage.go index 9b2981de88..8d80f9e8f3 100644 --- a/core/storage.go +++ b/core/storage.go @@ -36,6 +36,17 @@ type rootStorage struct { srv *serial.RootValue } +// SetFunctions sets the function hash and returns a new storage object. +func (r rootStorage) SetFunctions(ctx context.Context, h hash.Hash) (rootStorage, error) { + if len(r.srv.FunctionsBytes()) > 0 { + ret := r.clone() + copy(ret.srv.FunctionsBytes(), h[:]) + return ret, nil + } else { + return r.clone(), nil + } +} + // SetSequences sets the sequence hash and returns a new storage object. func (r rootStorage) SetSequences(ctx context.Context, h hash.Hash) (rootStorage, error) { if len(r.srv.SequencesBytes()) > 0 { @@ -116,6 +127,15 @@ func (r rootStorage) SetSchemas(ctx context.Context, dbSchemas []schema.Database return rootStorage{msg}, nil } +// GetFunctions returns the functions hash. +func (r rootStorage) GetFunctions() hash.Hash { + hashBytes := r.srv.FunctionsBytes() + if len(hashBytes) == 0 { + return hash.Hash{} + } + return hash.New(hashBytes) +} + // GetSequences returns the sequence hash. func (r rootStorage) GetSequences() hash.Hash { hashBytes := r.srv.SequencesBytes() diff --git a/flatbuffers/gen/serial/rootvalue.go b/flatbuffers/gen/serial/rootvalue.go index ca74118c3a..a8a0810ecf 100644 --- a/flatbuffers/gen/serial/rootvalue.go +++ b/flatbuffers/gen/serial/rootvalue.go @@ -235,7 +235,41 @@ func (rcv *RootValue) MutateTypes(j int, n byte) bool { return false } -const RootValueNumFields = 6 +func (rcv *RootValue) Functions(j int) byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(18)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.GetByte(a + flatbuffers.UOffsetT(j*1)) + } + return 0 +} + +func (rcv *RootValue) FunctionsLength() int { + o := flatbuffers.UOffsetT(rcv._tab.Offset(18)) + if o != 0 { + return rcv._tab.VectorLen(o) + } + return 0 +} + +func (rcv *RootValue) FunctionsBytes() []byte { + o := flatbuffers.UOffsetT(rcv._tab.Offset(18)) + if o != 0 { + return rcv._tab.ByteVector(o + rcv._tab.Pos) + } + return nil +} + +func (rcv *RootValue) MutateFunctions(j int, n byte) bool { + o := flatbuffers.UOffsetT(rcv._tab.Offset(18)) + if o != 0 { + a := rcv._tab.Vector(o) + return rcv._tab.MutateByte(a+flatbuffers.UOffsetT(j*1), n) + } + return false +} + +const RootValueNumFields = 8 func RootValueStart(builder *flatbuffers.Builder) { builder.StartObject(RootValueNumFields) @@ -270,6 +304,18 @@ func RootValueAddSequences(builder *flatbuffers.Builder, sequences flatbuffers.U func RootValueStartSequencesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { return builder.StartVector(1, numElems, 1) } +func RootValueAddTypes(builder *flatbuffers.Builder, types flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(6, flatbuffers.UOffsetT(types), 0) +} +func RootValueStartTypesVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} +func RootValueAddFunctions(builder *flatbuffers.Builder, functions flatbuffers.UOffsetT) { + builder.PrependUOffsetTSlot(7, flatbuffers.UOffsetT(functions), 0) +} +func RootValueStartFunctionsVector(builder *flatbuffers.Builder, numElems int) flatbuffers.UOffsetT { + return builder.StartVector(1, numElems, 1) +} func RootValueEnd(builder *flatbuffers.Builder) flatbuffers.UOffsetT { return builder.EndObject() } diff --git a/flatbuffers/serial/rootvalue.fbs b/flatbuffers/serial/rootvalue.fbs index bead004c0f..38757b08a4 100644 --- a/flatbuffers/serial/rootvalue.fbs +++ b/flatbuffers/serial/rootvalue.fbs @@ -29,6 +29,10 @@ table RootValue { schemas:[DatabaseSchema]; sequences:[ubyte]; + + types:[ubyte]; + + functions:[ubyte]; } table DatabaseSchema { diff --git a/server/ast/create_function.go b/server/ast/create_function.go index 5fe17d5b83..0056335875 100644 --- a/server/ast/create_function.go +++ b/server/ast/create_function.go @@ -79,6 +79,7 @@ func nodeCreateFunction(ctx *Context, node *tree.CreateFunction) (vitess.Stateme Statement: pgnodes.NewCreateFunction( tableName.Table(), schemaName, + node.Replace, retType, paramNames, paramTypes, diff --git a/server/functions/framework/catalog.go b/server/functions/framework/catalog.go index 5bbca1e861..ff92b6645a 100644 --- a/server/functions/framework/catalog.go +++ b/server/functions/framework/catalog.go @@ -35,12 +35,7 @@ var initializedFunctions = false // from within an init(). func RegisterFunction(f FunctionInterface) { if initializedFunctions { - // TODO: this should be able to handle overloads - name := strings.ToLower(f.GetName()) - if err := validateFunction(name, []FunctionInterface{f}); err != nil { - panic(err) // TODO: replace panics here with errors - } - compileNonOperatorFunction(name, []FunctionInterface{f}) + panic("attempted to register a function after the init() phase") } switch f := f.(type) { case Function0: diff --git a/server/functions/framework/provider.go b/server/functions/framework/provider.go index a66c7a9050..1c19a6914e 100644 --- a/server/functions/framework/provider.go +++ b/server/functions/framework/provider.go @@ -14,7 +14,13 @@ package framework -import "github.com/dolthub/go-mysql-server/sql" +import ( + "github.com/dolthub/go-mysql-server/sql" + + "github.com/dolthub/doltgresql/core" + "github.com/dolthub/doltgresql/core/id" + pgtypes "github.com/dolthub/doltgresql/server/types" +) // FunctionProvider is the special sql.FunctionProvider for Doltgres that allows us to handle functions that // are created by users. @@ -25,13 +31,53 @@ var _ sql.FunctionProvider = (*FunctionProvider)(nil) // Function implements the interface sql.FunctionProvider. func (fp *FunctionProvider) Function(ctx *sql.Context, name string) (sql.Function, bool) { // TODO: this should be configurable from within Dolt, rather than set on an external variable - // TODO: user functions should be accessible from the context, just like how sequences and types are handled - // For now, this just reads our global map (which also needs to be changed, since functions should not be global) - if f, ok := compiledCatalog[name]; ok { - return sql.FunctionN{ - Name: name, - Fn: f, - }, true + if !core.IsContextValid(ctx) { + return nil, false + } + funcCollection, err := core.GetFunctionsCollectionFromContext(ctx) + if err != nil { + return nil, false + } + typesCollection, err := core.GetTypesCollectionFromContext(ctx) + if err != nil { + return nil, false + } + funcName := id.NewFunction("pg_catalog", name) + overloads := funcCollection.GetFunctionOverloads(funcName) + if len(overloads) == 0 { + return nil, false + } + + overloadTree := NewOverloads() + for _, overload := range overloads { + returnType, ok := typesCollection.GetType(overload.ReturnType) + if !ok { + return nil, false + } + paramTypes := make([]*pgtypes.DoltgresType, len(overload.ParameterTypes)) + for i, paramType := range overload.ParameterTypes { + paramTypes[i], ok = typesCollection.GetType(paramType) + if !ok { + return nil, false + } + } + if err = overloadTree.Add(InterpretedFunction{ + ID: overload.ID, + ReturnType: returnType, + ParameterNames: overload.ParameterNames, + ParameterTypes: paramTypes, + Variadic: overload.Variadic, + IsNonDeterministic: overload.IsNonDeterministic, + Strict: overload.Strict, + Statements: overload.Operations, + }); err != nil { + return nil, false + } } - return nil, false + return sql.FunctionN{ + Name: name, + Fn: func(params ...sql.Expression) (sql.Expression, error) { + return NewCompiledFunction(name, params, overloadTree, false), nil + }, + }, true } diff --git a/server/node/create_function.go b/server/node/create_function.go index 846c1a6218..1a227274aa 100644 --- a/server/node/create_function.go +++ b/server/node/create_function.go @@ -19,8 +19,10 @@ import ( "github.com/dolthub/go-mysql-server/sql/plan" vitess "github.com/dolthub/vitess/go/vt/sqlparser" + "github.com/dolthub/doltgresql/core" + "github.com/dolthub/doltgresql/core/functions" + "github.com/dolthub/doltgresql/core/id" - "github.com/dolthub/doltgresql/server/functions/framework" "github.com/dolthub/doltgresql/server/plpgsql" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -29,6 +31,7 @@ import ( type CreateFunction struct { FunctionName string SchemaName string + Replace bool ReturnType *pgtypes.DoltgresType ParameterNames []string ParameterTypes []*pgtypes.DoltgresType @@ -43,6 +46,7 @@ var _ vitess.Injectable = (*CreateFunction)(nil) func NewCreateFunction( functionName string, schemaName string, + replace bool, retType *pgtypes.DoltgresType, paramNames []string, paramTypes []*pgtypes.DoltgresType, @@ -51,6 +55,7 @@ func NewCreateFunction( return &CreateFunction{ FunctionName: functionName, SchemaName: schemaName, + Replace: replace, ReturnType: retType, ParameterNames: paramNames, ParameterTypes: paramTypes, @@ -80,16 +85,33 @@ func (c *CreateFunction) RowIter(ctx *sql.Context, r sql.Row) (sql.RowIter, erro for i, typ := range c.ParameterTypes { idTypes[i] = typ.ID } - framework.RegisterFunction(framework.InterpretedFunction{ - ID: id.NewFunction(c.SchemaName, c.FunctionName, idTypes...), - ReturnType: c.ReturnType, + funcCollection, err := core.GetFunctionsCollectionFromContext(ctx) + if err != nil { + return nil, err + } + paramTypes := make([]id.Type, len(c.ParameterTypes)) + for i, paramType := range c.ParameterTypes { + paramTypes[i] = paramType.ID + } + funcID := id.NewFunction(c.SchemaName, c.FunctionName, idTypes...) + if c.Replace && funcCollection.HasFunction(funcID) { + if err = funcCollection.DropFunction(funcID); err != nil { + return nil, err + } + } + err = funcCollection.AddFunction(&functions.Function{ + ID: funcID, + ReturnType: c.ReturnType.ID, ParameterNames: c.ParameterNames, - ParameterTypes: c.ParameterTypes, + ParameterTypes: paramTypes, Variadic: false, // TODO: implement this IsNonDeterministic: true, Strict: c.Strict, - Statements: c.Statements, + Operations: c.Statements, }) + if err != nil { + return nil, err + } return sql.RowsToRowIter(), nil } diff --git a/server/plpgsql/interpreter_logic.go b/server/plpgsql/interpreter_logic.go index 036ac4fe7b..bded121758 100644 --- a/server/plpgsql/interpreter_logic.go +++ b/server/plpgsql/interpreter_logic.go @@ -20,8 +20,8 @@ import ( "github.com/dolthub/go-mysql-server/sql" "github.com/dolthub/go-mysql-server/sql/analyzer" - "github.com/dolthub/doltgresql/core" "github.com/dolthub/doltgresql/core/id" + "github.com/dolthub/doltgresql/core/typecollection" pgtypes "github.com/dolthub/doltgresql/server/types" ) @@ -36,6 +36,10 @@ type InterpretedFunction interface { QuerySingleReturn(ctx *sql.Context, stack InterpreterStack, stmt string, targetType *pgtypes.DoltgresType, bindings []string) (val any, err error) } +// GetTypesCollectionFromContext is declared within the core package, but is assigned to this variable to work around +// import cycles. +var GetTypesCollectionFromContext func(ctx *sql.Context) (*typecollection.TypeCollection, error) + // Call runs the contained operations on the given runner. func Call(ctx *sql.Context, iFunc InterpretedFunction, runner analyzer.StatementRunner, paramsAndReturn []*pgtypes.DoltgresType, vals []any) (any, error) { // Set up the initial state of the function @@ -84,7 +88,7 @@ func Call(ctx *sql.Context, iFunc InterpretedFunction, runner analyzer.Statement case OpCode_Case: // TODO: implement case OpCode_Declare: - typeCollection, err := core.GetTypesCollectionFromContext(ctx) + typeCollection, err := GetTypesCollectionFromContext(ctx) if err != nil { return nil, err } diff --git a/testing/go/create_function_test.go b/testing/go/create_function_test.go index aa49bfb26e..bf1e06b470 100644 --- a/testing/go/create_function_test.go +++ b/testing/go/create_function_test.go @@ -398,5 +398,58 @@ $$ LANGUAGE plpgsql;`, }, }, }, + { + Name: "Overloading", + SetUpScript: []string{`CREATE FUNCTION interpreted_overload(input TEXT) RETURNS TEXT AS $$ +DECLARE + var1 TEXT; +BEGIN + IF length(input) > 3 THEN + var1 := input || '_long'; + ELSE + var1 := input; + END IF; + RETURN var1; +END; +$$ LANGUAGE plpgsql;`, + `CREATE FUNCTION interpreted_overload(input INT4) RETURNS INT4 AS $$ +DECLARE + var1 INT4; +BEGIN + IF input > 3 THEN + var1 := -input; + ELSE + var1 := input; + END IF; + RETURN var1; +END; +$$ LANGUAGE plpgsql;`}, + Assertions: []ScriptTestAssertion{ + { + Query: "SELECT interpreted_overload('abc');", + Expected: []sql.Row{ + {"abc"}, + }, + }, + { + Query: "SELECT interpreted_overload('abcd');", + Expected: []sql.Row{ + {"abcd_long"}, + }, + }, + { + Query: "SELECT interpreted_overload(3);", + Expected: []sql.Row{ + {3}, + }, + }, + { + Query: "SELECT interpreted_overload(4);", + Expected: []sql.Row{ + {-4}, + }, + }, + }, + }, }) } diff --git a/utils/reader.go b/utils/reader.go index 19cef53cf9..3b4e5bfa0d 100644 --- a/utils/reader.go +++ b/utils/reader.go @@ -349,6 +349,26 @@ func (reader *Reader) StringSlice() []string { return vals } +// IdSlice reads a slice of internal IDs. +func (reader *Reader) IdSlice() []id.Id { + count := reader.VariableUint() + vals := make([]id.Id, count) + for i := uint64(0); i < count; i++ { + vals[i] = reader.Id() + } + return vals +} + +// IdTypeSlice reads a slice of internal type IDs. +func (reader *Reader) IdTypeSlice() []id.Type { + count := reader.VariableUint() + vals := make([]id.Type, count) + for i := uint64(0); i < count; i++ { + vals[i] = id.Type(reader.Id()) + } + return vals +} + // IsEmpty returns true when all of the data has been read. func (reader *Reader) IsEmpty() bool { return reader.offset >= uint64(len(reader.buf)) diff --git a/utils/writer.go b/utils/writer.go index 5607e6f7c8..bbf6094ef9 100644 --- a/utils/writer.go +++ b/utils/writer.go @@ -273,6 +273,22 @@ func (writer *Writer) StringSlice(vals []string) { } } +// IdSlice writes a slice of internal IDs. +func (writer *Writer) IdSlice(vals []id.Id) { + writer.VariableUint(uint64(len(vals))) + for i := range vals { + writer.Id(vals[i]) + } +} + +// IdTypeSlice writes a slice of internal type IDs. +func (writer *Writer) IdTypeSlice(vals []id.Type) { + writer.VariableUint(uint64(len(vals))) + for i := range vals { + writer.Id(vals[i].AsId()) + } +} + // Data returns the data written to the Writer. func (writer *Writer) Data() []byte { return writer.buf.Bytes()