Skip to content

Commit

Permalink
Merge pull request #1213 from dolthub/daylon/root-functions
Browse files Browse the repository at this point in the history
Changed created functions to persist on the root
  • Loading branch information
Hydrocharged authored Feb 18, 2025
2 parents e36147c + 58a3ff4 commit 9677227
Show file tree
Hide file tree
Showing 17 changed files with 683 additions and 51 deletions.
38 changes: 37 additions & 1 deletion core/context.go
Original file line number Diff line number Diff line change
Expand Up @@ -16,12 +16,12 @@ package core

import (
"github.com/cockroachdb/errors"

"github.com/dolthub/dolt/go/libraries/doltcore/doltdb"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/dsess"
"github.com/dolthub/dolt/go/libraries/doltcore/sqle/resolve"
"github.com/dolthub/go-mysql-server/sql"

"github.com/dolthub/doltgresql/core/functions"
"github.com/dolthub/doltgresql/core/sequences"
"github.com/dolthub/doltgresql/core/typecollection"
)
Expand All @@ -30,6 +30,7 @@ import (
type contextValues struct {
collection *sequences.Collection
types *typecollection.TypeCollection
funcs *functions.Collection
pgCatalogCache any
}

Expand Down Expand Up @@ -63,6 +64,17 @@ func getRootFromContext(ctx *sql.Context) (*dsess.DoltSession, *RootValue, error
return session, state.WorkingRoot().(*RootValue), nil
}

// IsContextValid returns whether the context is valid for use with any of the functions in the package. If this is not
// false, then there's a high likelihood that the context is being used in a temporary scenario (such as setting up the
// database, etc.).
func IsContextValid(ctx *sql.Context) bool {
if ctx == nil {
return false
}
_, ok := ctx.Session.(*dsess.DoltSession)
return ok
}

// GetPgCatalogCache returns a cache of data for pg_catalog tables. This function should only be used by
// pg_catalog table handlers. The catalog cache instance stores generated pg_catalog table data so that
// it only has to generated table data once per query.
Expand Down Expand Up @@ -185,6 +197,26 @@ func GetSqlTableFromContext(ctx *sql.Context, databaseName string, tableName dol
return nil, nil
}

// GetFunctionsCollectionFromContext returns the functions collection from the given context. Will always return a
// collection if no error is returned.
func GetFunctionsCollectionFromContext(ctx *sql.Context) (*functions.Collection, error) {
cv, err := getContextValues(ctx)
if err != nil {
return nil, err
}
if cv.funcs == nil {
_, root, err := getRootFromContext(ctx)
if err != nil {
return nil, err
}
cv.funcs, err = root.GetFunctions(ctx)
if err != nil {
return nil, err
}
}
return cv.funcs, nil
}

// GetSequencesCollectionFromContext returns the given sequence collection from the context. Will always return a collection if
// no error is returned.
func GetSequencesCollectionFromContext(ctx *sql.Context) (*sequences.Collection, error) {
Expand Down Expand Up @@ -247,6 +279,10 @@ func CloseContextRootFinalizer(ctx *sql.Context) error {
if err != nil {
return err
}
newRoot, err = newRoot.PutFunctions(ctx, cv.funcs)
if err != nil {
return err
}
if newRoot != nil {
if err = session.SetWorkingRoot(ctx, ctx.GetCurrentDatabase(), newRoot); err != nil {
// TODO: We need a way to see if the session has a writeable working root
Expand Down
128 changes: 128 additions & 0 deletions core/functions/function.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,128 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package functions

import (
"maps"
"sync"

"github.com/cockroachdb/errors"

"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/server/plpgsql"
)

// Collection contains a collection of functions.
type Collection struct {
funcMap map[id.Function]*Function
overloadMap map[id.Function][]*Function
mutex *sync.Mutex
}

// Function represents a created function.
type Function struct {
ID id.Function
ReturnType id.Type
ParameterNames []string
ParameterTypes []id.Type
Variadic bool
IsNonDeterministic bool
Strict bool
Operations []plpgsql.InterpreterOperation
}

// GetFunction returns the function with the given ID. Returns nil if the function cannot be found.
func (pgf *Collection) GetFunction(funcID id.Function) *Function {
pgf.mutex.Lock()
defer pgf.mutex.Unlock()

if f, ok := pgf.funcMap[funcID]; ok {
return f
}
return nil
}

// GetFunctionOverloads returns the overloads for the function matching the schema and the function name. The parameter
// types are ignored when searching for overloads.
func (pgf *Collection) GetFunctionOverloads(funcID id.Function) []*Function {
pgf.mutex.Lock()
defer pgf.mutex.Unlock()

funcNameOnly := id.NewFunction(funcID.SchemaName(), funcID.FunctionName())
return pgf.overloadMap[funcNameOnly]
}

// HasFunction returns whether the function is present.
func (pgf *Collection) HasFunction(funcID id.Function) bool {
return pgf.GetFunction(funcID) != nil
}

// AddFunction adds a new function.
func (pgf *Collection) AddFunction(f *Function) error {
pgf.mutex.Lock()
defer pgf.mutex.Unlock()

if _, ok := pgf.funcMap[f.ID]; ok {
return errors.Errorf(`function "%s" already exists with same argument types`, f.ID.FunctionName())
}
pgf.funcMap[f.ID] = f
funcNameOnly := id.NewFunction(f.ID.SchemaName(), f.ID.FunctionName())
pgf.overloadMap[funcNameOnly] = append(pgf.overloadMap[funcNameOnly], f)
return nil
}

// DropFunction drops an existing function.
func (pgf *Collection) DropFunction(funcID id.Function) error {
pgf.mutex.Lock()
defer pgf.mutex.Unlock()

if _, ok := pgf.funcMap[funcID]; ok {
delete(pgf.funcMap, funcID)
funcNameOnly := id.NewFunction(funcID.SchemaName(), funcID.FunctionName())
for i, f := range pgf.overloadMap[funcNameOnly] {
if f.ID == funcID {
pgf.overloadMap[funcNameOnly] = append(pgf.overloadMap[funcNameOnly][:i], pgf.overloadMap[funcNameOnly][i+1:]...)
break
}
}
return nil
}
return errors.Errorf(`function %s does not exist`, funcID.FunctionName())
}

// IterateFunctions iterates over all functions in the collection.
func (pgf *Collection) IterateFunctions(callback func(f *Function) error) error {
pgf.mutex.Lock()
defer pgf.mutex.Unlock()

for _, f := range pgf.funcMap {
if err := callback(f); err != nil {
return err
}
}
return nil
}

// Clone returns a new *Collection with the same contents as the original.
func (pgf *Collection) Clone() *Collection {
pgf.mutex.Lock()
defer pgf.mutex.Unlock()

return &Collection{
funcMap: maps.Clone(pgf.funcMap),
overloadMap: maps.Clone(pgf.overloadMap),
mutex: &sync.Mutex{},
}
}
39 changes: 39 additions & 0 deletions core/functions/merge.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,39 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package functions

import (
"context"

"github.com/cockroachdb/errors"
)

// Merge handles merging functions on our root and their root.
func Merge(ctx context.Context, ourCollection, theirCollection, ancCollection *Collection) (*Collection, error) {
mergedCollection := ourCollection.Clone()
err := theirCollection.IterateFunctions(func(theirFunc *Function) error {
// If we don't have the sequence, then we simply add it
if !mergedCollection.HasFunction(theirFunc.ID) {
newFunc := *theirFunc
return mergedCollection.AddFunction(&newFunc)
}
// TODO: figure out a decent merge strategy
return errors.Errorf(`unable to merge "%s"`, theirFunc.ID.AsId().String())
})
if err != nil {
return nil, err
}
return mergedCollection, nil
}
120 changes: 120 additions & 0 deletions core/functions/serialization.go
Original file line number Diff line number Diff line change
@@ -0,0 +1,120 @@
// Copyright 2024 Dolthub, Inc.
//
// Licensed under the Apache License, Version 2.0 (the "License");
// you may not use this file except in compliance with the License.
// You may obtain a copy of the License at
//
// http://www.apache.org/licenses/LICENSE-2.0
//
// Unless required by applicable law or agreed to in writing, software
// distributed under the License is distributed on an "AS IS" BASIS,
// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
// See the License for the specific language governing permissions and
// limitations under the License.

package functions

import (
"context"
"sync"

"github.com/cockroachdb/errors"

"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/server/plpgsql"
"github.com/dolthub/doltgresql/utils"
)

// Serialize returns the Collection as a byte slice. If the Collection is nil, then this returns a nil slice.
func (pgf *Collection) Serialize(ctx context.Context) ([]byte, error) {
if pgf == nil {
return nil, nil
}
pgf.mutex.Lock()
defer pgf.mutex.Unlock()

// Write all of the functions to the writer
writer := utils.NewWriter(256)
writer.VariableUint(0) // Version
funcIDs := utils.GetMapKeysSorted(pgf.funcMap)
writer.VariableUint(uint64(len(funcIDs)))
for _, funcID := range funcIDs {
f := pgf.funcMap[funcID]
writer.Id(f.ID.AsId())
writer.Id(f.ReturnType.AsId())
writer.StringSlice(f.ParameterNames)
writer.IdTypeSlice(f.ParameterTypes)
writer.Bool(f.Variadic)
writer.Bool(f.IsNonDeterministic)
writer.Bool(f.Strict)
// Write the operations
writer.VariableUint(uint64(len(f.Operations)))
for _, op := range f.Operations {
writer.Uint16(uint16(op.OpCode))
writer.String(op.PrimaryData)
writer.StringSlice(op.SecondaryData)
writer.String(op.Target)
writer.Int32(int32(op.Index))
}
}

return writer.Data(), nil
}

// Deserialize returns the Collection that was serialized in the byte slice. Returns an empty Collection if data is nil
// or empty.
func Deserialize(ctx context.Context, data []byte) (*Collection, error) {
if len(data) == 0 {
return &Collection{
funcMap: make(map[id.Function]*Function),
overloadMap: make(map[id.Function][]*Function),
mutex: &sync.Mutex{},
}, nil
}
funcMap := make(map[id.Function]*Function)
overloadMap := make(map[id.Function][]*Function)
reader := utils.NewReader(data)
version := reader.VariableUint()
if version != 0 {
return nil, errors.Errorf("version %d of functions is not supported, please upgrade the server", version)
}

// Read from the reader
numOfFunctions := reader.VariableUint()
for i := uint64(0); i < numOfFunctions; i++ {
f := &Function{}
f.ID = id.Function(reader.Id())
f.ReturnType = id.Type(reader.Id())
f.ParameterNames = reader.StringSlice()
f.ParameterTypes = reader.IdTypeSlice()
f.Variadic = reader.Bool()
f.IsNonDeterministic = reader.Bool()
f.Strict = reader.Bool()
// Read the operations
opCount := reader.VariableUint()
f.Operations = make([]plpgsql.InterpreterOperation, opCount)
for opIdx := uint64(0); opIdx < opCount; opIdx++ {
op := plpgsql.InterpreterOperation{}
op.OpCode = plpgsql.OpCode(reader.Uint16())
op.PrimaryData = reader.String()
op.SecondaryData = reader.StringSlice()
op.Target = reader.String()
op.Index = int(reader.Int32())
f.Operations[opIdx] = op
}
// Add the function to each map
funcMap[f.ID] = f
funcNameOnly := id.NewFunction(f.ID.SchemaName(), f.ID.FunctionName())
overloadMap[funcNameOnly] = append(overloadMap[funcNameOnly], f)
}
if !reader.IsEmpty() {
return nil, errors.Errorf("extra data found while deserializing functions")
}

// Return the deserialized object
return &Collection{
funcMap: funcMap,
overloadMap: overloadMap,
mutex: &sync.Mutex{},
}, nil
}
2 changes: 2 additions & 0 deletions core/init.go
Original file line number Diff line number Diff line change
Expand Up @@ -19,6 +19,7 @@ import (
"github.com/dolthub/dolt/go/store/types"

"github.com/dolthub/doltgresql/core/id"
"github.com/dolthub/doltgresql/server/plpgsql"
)

// Init initializes this package.
Expand All @@ -27,5 +28,6 @@ func Init() {
doltdb.NewRootValue = newRootValue
types.DoltgresRootValueHumanReadableStringAtIndentationLevel = rootValueHumanReadableStringAtIndentationLevel
types.DoltgresRootValueWalkAddrs = rootValueWalkAddrs
plpgsql.GetTypesCollectionFromContext = GetTypesCollectionFromContext
id.RegisterListener(sequenceIDListener{}, id.Section_Table)
}
Loading

0 comments on commit 9677227

Please sign in to comment.