From 639d9da3f52e8898a9ff2057f5f70beed0fe87fc Mon Sep 17 00:00:00 2001 From: Fantix King Date: Wed, 20 Mar 2019 18:34:34 -0500 Subject: [PATCH] feat(auth): added auth/request * Also completed auth/proxy --- arborist/auth.go | 63 ++++++++++++++++++++++++++++++++ arborist/resource_rules.go | 46 ++++++++++++----------- arborist/resource_rules.y | 46 ++++++++++++----------- arborist/resource_rules_test.go | 6 +-- arborist/server.go | 65 ++++++++++++++++++++++++++++++--- 5 files changed, 176 insertions(+), 50 deletions(-) diff --git a/arborist/auth.go b/arborist/auth.go index 32d9db51..1fd132a3 100644 --- a/arborist/auth.go +++ b/arborist/auth.go @@ -2,6 +2,8 @@ package arborist import ( "encoding/json" + "github.com/jmoiron/sqlx" + "github.com/lib/pq" ) type AuthRequestJSON struct { @@ -76,3 +78,64 @@ func (requestJSON *AuthRequestJSON_Request) UnmarshalJSON(data []byte) error { return nil } + +func authorize(db *sqlx.DB, token *TokenInfo, resource string, service string, method string) (bool, error) { + exp, args, err := Parse(resource) + if err != nil { + return false, err + } + + stmt := ` +SELECT coalesce(text2ltree("unnest") @> allowed, FALSE) FROM ( + SELECT array_agg(resource.path) AS allowed + FROM usr + LEFT JOIN usr_policy + ON usr_policy.usr_id = usr.id + LEFT JOIN policy_resource + ON policy_resource.policy_id = usr_policy.policy_id + LEFT JOIN resource + ON resource.id = policy_resource.resource_id + WHERE usr.name = $1 + AND EXISTS ( + SELECT 1 + FROM policy_role + LEFT JOIN permission + ON permission.role_id = policy_role.role_id + WHERE policy_role.policy_id = usr_policy.policy_id + AND permission.service = $2 + AND permission.method = $3 + ) + AND ($4 OR usr_policy.policy_id IN ( + SELECT id + FROM policy + WHERE policy.name = ANY($5) + )) +) _, unnest($6::text[]); +` + + rows, err := db.Query(stmt, + token.username, // $1 + service, // $2 + method, // $3 + len(token.policies) == 0, // $4 + pq.Array(token.policies), // $5 + pq.Array(args), // $6 + ) + if err != nil { + return false, err + } + + i := 0 + vars := make(map[string]bool) + for rows.Next() { + var result bool + err = rows.Scan(&result) + if err != nil { + return false, err + } + vars[args[i]] = result + i ++ + } + + return Eval(exp, vars) +} diff --git a/arborist/resource_rules.go b/arborist/resource_rules.go index fceda2b3..9c85e42a 100644 --- a/arborist/resource_rules.go +++ b/arborist/resource_rules.go @@ -72,7 +72,7 @@ const yyInitialStackSize = 16 type Lexer struct { scanner.Scanner - Vars map[string]interface{} + args map[string]interface{} result Expression error string } @@ -91,6 +91,7 @@ func (l *Lexer) Lex(lval *yySymType) int { default: if lit != "" && lit != "(" && lit != ")" { tok = VARIABLE + l.args[lit] = nil } } lval.token = Token{token: tok, literal: lit} @@ -101,26 +102,17 @@ func (l *Lexer) Error(e string) { l.error = e } -func (l *Lexer) Eval(e Expression) (bool, error) { +func Eval(e Expression, vars map[string]bool) (bool, error) { switch t := e.(type) { case Variable: - if v, ok := l.Vars[t.literal]; ok { - switch v.(type) { - case bool: - if v.(bool) { - return true, nil - } else { - return false, nil - } - default: - return false, errors.New("Parameter type must be boolean") - } + if v, ok := vars[t.literal]; ok { + return v, nil } return false, errors.New("Undefined symbol: " + t.literal) case ParenExpr: - return l.Eval(t.SubExpr) + return Eval(t.SubExpr, vars) case UnaryExpr: - right, err := l.Eval(t.right) + right, err := Eval(t.right, vars) if err != nil { return false, err } @@ -129,11 +121,11 @@ func (l *Lexer) Eval(e Expression) (bool, error) { return !right, nil } case AssocExpr: - left, err := l.Eval(t.left) + left, err := Eval(t.left, vars) if err != nil { return false, err } - right, err := l.Eval(t.right) + right, err := Eval(t.right, vars) if err != nil { return false, err } @@ -147,14 +139,26 @@ func (l *Lexer) Eval(e Expression) (bool, error) { return false, errors.New("Unexpected error") } -func Parse(exp string, vars map[string]interface{}) (bool, error) { +func Parse(exp string) (Expression, []string, error) { l := new(Lexer) - l.Vars = vars + l.args = make(map[string]interface{}) l.Init(strings.NewReader(exp)) if yyParse(l) != 0 { - return false, errors.New(l.error) + return nil, nil, errors.New(l.error) + } + args := make([]string, 0) + for arg := range l.args { + args = append(args, arg) + } + return l.result, args, nil +} + +func Run(exp string, vars map[string]bool) (bool, error) { + e, _, err := Parse(exp) + if err != nil { + return false, err } - return l.Eval(l.result) + return Eval(e, vars) } func init() { diff --git a/arborist/resource_rules.y b/arborist/resource_rules.y index 1c4d48de..563ff8d8 100644 --- a/arborist/resource_rules.y +++ b/arborist/resource_rules.y @@ -75,7 +75,7 @@ expr %% type Lexer struct { scanner.Scanner - Vars map[string]interface{} + args map[string]interface{} result Expression error string } @@ -94,6 +94,7 @@ func (l *Lexer) Lex(lval *yySymType) int { default: if lit != "" && lit != "(" && lit != ")" { tok = VARIABLE + l.args[lit] = nil } } lval.token = Token{token: tok, literal: lit} @@ -104,26 +105,17 @@ func (l *Lexer) Error(e string) { l.error = e } -func (l *Lexer) Eval(e Expression) (bool, error) { +func Eval(e Expression, vars map[string]bool) (bool, error) { switch t := e.(type) { case Variable: - if v, ok := l.Vars[t.literal]; ok { - switch v.(type) { - case bool: - if v.(bool) { - return true, nil - } else { - return false, nil - } - default: - return false, errors.New("Parameter type must be boolean") - } + if v, ok := vars[t.literal]; ok { + return v, nil } return false, errors.New("Undefined symbol: " + t.literal) case ParenExpr: - return l.Eval(t.SubExpr) + return Eval(t.SubExpr, vars) case UnaryExpr: - right, err := l.Eval(t.right) + right, err := Eval(t.right, vars) if err != nil { return false, err } @@ -132,11 +124,11 @@ func (l *Lexer) Eval(e Expression) (bool, error) { return ! right, nil } case AssocExpr: - left, err := l.Eval(t.left) + left, err := Eval(t.left, vars) if err != nil { return false, err } - right, err := l.Eval(t.right) + right, err := Eval(t.right, vars) if err != nil { return false, err } @@ -150,14 +142,26 @@ func (l *Lexer) Eval(e Expression) (bool, error) { return false, errors.New("Unexpected error") } -func Parse(exp string, vars map[string]interface{}) (bool, error) { +func Parse(exp string) (Expression, []string, error) { l := new(Lexer) - l.Vars = vars + l.args = make(map[string]interface{}) l.Init(strings.NewReader(exp)) if yyParse(l) != 0 { - return false, errors.New(l.error) + return nil, nil, errors.New(l.error) } - return l.Eval(l.result) + args := make([]string, 0) + for arg := range l.args { + args = append(args, arg) + } + return l.result, args, nil +} + +func Run(exp string, vars map[string]bool) (bool, error) { + e, _, err := Parse(exp) + if err != nil { + return false, err + } + return Eval(e, vars) } func init() { diff --git a/arborist/resource_rules_test.go b/arborist/resource_rules_test.go index f6a5176d..2290c9cf 100644 --- a/arborist/resource_rules_test.go +++ b/arborist/resource_rules_test.go @@ -5,8 +5,8 @@ import ( "testing" ) -func ParseOrFail(t *testing.T, exp string, vars map[string]interface{}) bool { - rv, err := Parse(exp, vars) +func ParseOrFail(t *testing.T, exp string, vars map[string]bool) bool { + rv, err := Run(exp, vars) if err != nil { t.Error(err) } @@ -24,7 +24,7 @@ func assertEqual(t *testing.T, a interface{}, b interface{}, message string) { } func TestParse(t *testing.T) { - vars := map[string]interface{}{ + vars := map[string]bool{ "T": true, "F": false, } diff --git a/arborist/server.go b/arborist/server.go index 34be0289..55898653 100644 --- a/arborist/server.go +++ b/arborist/server.go @@ -81,8 +81,8 @@ func (server *Server) MakeRouter(out io.Writer) http.Handler { router.HandleFunc("/health", server.handleHealth).Methods("GET") - //router.Handle("/auth/proxy", http.HandlerFunc(server.handleAuthProxy)).Methods("POST") - //router.Handle("/auth/request", server.handleAuthRequest).Methods("POST") + router.Handle("/auth/proxy", http.HandlerFunc(server.handleAuthProxy)).Methods("GET") + router.Handle("/auth/request", http.HandlerFunc(parseJSON(server.handleAuthRequest))).Methods("POST") //router.Handle("/auth/resources", server.handleListAuthResources).Methods("POST") router.Handle("/policy", http.HandlerFunc(server.handlePolicyList)).Methods("GET") @@ -133,7 +133,6 @@ func (server *Server) handleHealth(w http.ResponseWriter, r *http.Request) { w.WriteHeader(http.StatusOK) } -/* func (server *Server) handleAuthProxy(w http.ResponseWriter, r *http.Request) { // Get QS arguments resourcePathQS, ok := r.URL.Query()["resource"] @@ -183,9 +182,65 @@ func (server *Server) handleAuthProxy(w http.ResponseWriter, r *http.Request) { return } - // TODO + w.Header().Set("REMOTE_USER", info.username) + + rv, err := authorize(server.db, info, resourcePath, service, method) + if err != nil { + msg := fmt.Sprintf("could not authorize: %s", err.Error()) + server.logger.Info("tried to handle auth request but input was invalid: %s", msg) + response := newErrorResponse(msg, 400, nil) + _ = response.write(w, r) + return + } + if !rv { + errResponse := newErrorResponse( + "Unauthorized: user does not have access to this resource", 403, nil) + _ = errResponse.write(w, r) + } +} + +func (server *Server) handleAuthRequest(w http.ResponseWriter, r *http.Request, body []byte) { + authRequest := &AuthRequest{} + err := json.Unmarshal(body, authRequest) + if err != nil { + msg := fmt.Sprintf("could not parse auth request from JSON: %s", err.Error()) + server.logger.Info("tried to handle auth request but input was invalid: %s", msg) + response := newErrorResponse(msg, 400, nil) + _ = response.write(w, r) + return + } + + var aud []string + if authRequest.User.Audiences == nil { + aud = []string{"openid"} + } else { + aud = make([]string, len(authRequest.User.Audiences)) + copy(aud, authRequest.User.Audiences) + } + + info, err := server.decodeToken(authRequest.User.Token, aud) + if err != nil { + server.logger.Info(err.Error()) + errResponse := newErrorResponse(err.Error(), 401, &err) + _ = errResponse.write(w, r) + return + } + + if authRequest.User.Policies != nil { + info.policies = authRequest.User.Policies + } + + rv, err := authorize(server.db, info, authRequest.Request.Resource, + authRequest.Request.Action.Service, authRequest.Request.Action.Method) + if err != nil { + msg := fmt.Sprintf("could not authorize: %s", err.Error()) + server.logger.Info("tried to handle auth request but input was invalid: %s", msg) + response := newErrorResponse(msg, 400, nil) + _ = response.write(w, r) + return + } + _ = jsonResponseFrom(AuthResponse{rv}, 200).write(w, r) } -*/ func (server *Server) handlePolicyList(w http.ResponseWriter, r *http.Request) { policies, err := listPoliciesFromDb(server.db)