Skip to content

Commit

Permalink
Add template mapping functions
Browse files Browse the repository at this point in the history
These are handy to deal with properties in protobufs within minder's
templates.

Signed-off-by: Juan Antonio Osorio <[email protected]>
  • Loading branch information
JAORMX committed Oct 10, 2024
1 parent ae5c149 commit 2f6338f
Show file tree
Hide file tree
Showing 2 changed files with 219 additions and 4 deletions.
73 changes: 69 additions & 4 deletions internal/util/safe_template.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@ import (
"fmt"
htmltemplate "html/template"
"io"
"reflect"
"text/template"

"github.com/rs/zerolog"
Expand All @@ -32,6 +33,17 @@ var (
ErrExceededSizeLimit = errors.New("exceeded size limit")
)

var (
// TemplateFuncs is a map of functions that can be used in templates
// It introduces two custom functions:
// - asMap: converts a structpb (or anything that implements the AsMap function call) to a map
// - mapGet: returns the value of a key in a map
TemplateFuncs = template.FuncMap{
"asMap": asMap,
"mapGet": mapGet,
}
)

// SafeTemplate is a `template` wrapper that ensures that the template is
// rendered in a safe and secure manner. That is, with memory limits
// and timeouts.
Expand All @@ -44,9 +56,56 @@ type templater interface {
Name() string
}

// This is a utility interface that allows us to accept any type
type asMapper interface {
AsMap() map[string]interface{}
}

// asMap converts a structpb to a map
func asMap(s any) (reflect.Value, error) {
if s == nil {
return reflect.Value{}, fmt.Errorf("asMap called with nil")
}

inspb, ok := s.(asMapper)
if !ok {
return reflect.Value{}, fmt.Errorf("invalid type: %T", s)
}

return reflect.ValueOf(inspb.AsMap()), nil
}

// mapGet returns the value of a key in a map
// The map could be a map[string]interface{} or a asMapper
// So we need to handle both cases
func mapGet(m any, key string) (reflect.Value, error) {
if m == nil {
return reflect.Value{}, fmt.Errorf("map is nil")
}

// Check if the map is a map[string]interface{}
if mm, ok := m.(map[string]interface{}); ok {
return valueOfKey(mm, key)
}

if mm, ok := m.(asMapper); ok {
mm := mm.AsMap()
return valueOfKey(mm, key)
}

return reflect.Value{}, fmt.Errorf("invalid type: %T", m)
}

func valueOfKey(m map[string]interface{}, key string) (reflect.Value, error) {
if v, ok := m[key]; ok {
return reflect.ValueOf(v), nil
}
return reflect.Value{}, fmt.Errorf("key not found: %s", key)
}

// NewSafeTextTemplate creates a new SafeTemplate for text templates
func NewSafeTextTemplate(tmpl *string, name string) (*SafeTemplate, error) {
t, err := parseNewTextTemplate(tmpl, name)
t, err := parseNewTextTemplate(tmpl, name, TemplateFuncs)
if err != nil {
return nil, err
}
Expand All @@ -58,7 +117,7 @@ func NewSafeTextTemplate(tmpl *string, name string) (*SafeTemplate, error) {

// NewSafeHTMLTemplate creates a new SafeTemplate for HTML templates
func NewSafeHTMLTemplate(tmpl *string, name string) (*SafeTemplate, error) {
t, err := parseNewHtmlTemplate(tmpl, name)
t, err := parseNewHtmlTemplate(tmpl, name, TemplateFuncs)
if err != nil {
return nil, err
}
Expand Down Expand Up @@ -96,12 +155,15 @@ func (t *SafeTemplate) Execute(ctx context.Context, w io.Writer, data any, limit
}

// parseNewTextTemplate parses a named template from a string, ensuring it is not empty
func parseNewTextTemplate(tmpl *string, name string) (*template.Template, error) {
func parseNewTextTemplate(tmpl *string, name string, fnmap template.FuncMap) (*template.Template, error) {
if tmpl == nil || len(*tmpl) == 0 {
return nil, fmt.Errorf("missing template")
}

t := template.New(name).Option("missingkey=error")
if fnmap != nil {
t = t.Funcs(fnmap)
}
t, err := t.Parse(*tmpl)
if err != nil {
return nil, fmt.Errorf("cannot parse template: %w", err)
Expand All @@ -111,12 +173,15 @@ func parseNewTextTemplate(tmpl *string, name string) (*template.Template, error)
}

// parseNewHtmlTemplate parses a named template from a string, ensuring it is not empty
func parseNewHtmlTemplate(tmpl *string, name string) (*htmltemplate.Template, error) {
func parseNewHtmlTemplate(tmpl *string, name string, fnmap template.FuncMap) (*htmltemplate.Template, error) {
if tmpl == nil || len(*tmpl) == 0 {
return nil, fmt.Errorf("missing template")
}

t := htmltemplate.New(name).Option("missingkey=error")
if fnmap != nil {
t = t.Funcs(fnmap)
}
t, err := t.Parse(*tmpl)
if err != nil {
return nil, fmt.Errorf("cannot parse template: %w", err)
Expand Down
150 changes: 150 additions & 0 deletions internal/util/safe_template_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,8 @@ import (
"testing"

"github.com/stretchr/testify/assert"
"github.com/stretchr/testify/require"
structpb "google.golang.org/protobuf/types/known/structpb"

"github.com/stacklok/minder/internal/util"
)
Expand Down Expand Up @@ -208,3 +210,151 @@ func verybigstring(n int) string {
}
return s
}

func TestRenderStructPB(t *testing.T) {
t.Parallel()

const limit = 1024

type args struct {
tmpl string
s any
}
tests := []struct {
name string
args args
expected string
wantErr bool
}{
{
name: "asMap: valid template",
args: args{
tmpl: "{{ with $m := asMap . }}{{ $m.name }}{{ end }}",
s: &structpb.Struct{
Fields: map[string]*structpb.Value{
"name": {
Kind: &structpb.Value_StringValue{
StringValue: "test",
},
},
},
},
},
expected: "test",
wantErr: false,
},
{
name: "asMap: using wrong key",
args: args{
tmpl: "{{ with $m := asMap . }}{{ $m.name2 }}{{ end }}",
s: &structpb.Struct{
Fields: map[string]*structpb.Value{
"name": {
Kind: &structpb.Value_StringValue{
StringValue: "test",
},
},
},
},
},
expected: "",
wantErr: true,
},
{
name: "asMap: using wrong type",
args: args{
tmpl: "{{ with $m := asMap . }}{{ $m.name }}{{ end }}",
s: "test",
},
expected: "",
wantErr: true,
},
{
name: "asMap: nil structpb",
args: args{
tmpl: "{{ with $m := asMap . }}{{ $m.name }}{{ end }}",
s: nil,
},
expected: "",
wantErr: true,
},
{
name: "mapGet: valid with map[string]any",
args: args{
tmpl: "{{ mapGet . \"name\" }}",
s: map[string]any{
"name": "test",
},
},
expected: "test",
wantErr: false,
},
{
name: "mapGet: valid with asMapper",
args: args{
tmpl: "{{ mapGet . \"name\" }}",
s: &structpb.Struct{
Fields: map[string]*structpb.Value{
"name": {
Kind: &structpb.Value_StringValue{
StringValue: "test",
},
},
},
},
},
expected: "test",
wantErr: false,
},
{
name: "mapGet: using wrong key",
args: args{
tmpl: "{{ mapGet . \"name2\" }}",
s: map[string]any{
"name": "test",
},
},
expected: "",
wantErr: true,
},
{
name: "mapGet: using wrong type",
args: args{
tmpl: "{{ mapGet . \"name\" }}",
s: "test",
},
expected: "",
wantErr: true,
},
{
name: "mapGet: nil map",
args: args{
tmpl: "{{ mapGet . \"name\" }}",
s: nil,
},
expected: "",
wantErr: true,
},
}

for _, tt := range tests {
tt := tt

t.Run(tt.name, func(t *testing.T) {
t.Parallel()

tmpl, err := util.NewSafeTextTemplate(&tt.args.tmpl, "test")
// We're not testing the template parsing here
require.NoError(t, err, "unexpected error")

out, err := tmpl.Render(context.Background(), tt.args.s, limit)
if tt.wantErr {
assert.Error(t, err, "expected error")
} else {
assert.NoError(t, err, "unexpected error")
assert.Equal(t, tt.expected, out, "expected output")
}
})
}

}

0 comments on commit 2f6338f

Please sign in to comment.