Skip to content

Commit

Permalink
feat(auth): added auth/request
Browse files Browse the repository at this point in the history
* Also completed auth/proxy
  • Loading branch information
fantix committed Mar 21, 2019
1 parent 353702c commit 639d9da
Show file tree
Hide file tree
Showing 5 changed files with 176 additions and 50 deletions.
63 changes: 63 additions & 0 deletions arborist/auth.go
Original file line number Diff line number Diff line change
Expand Up @@ -2,6 +2,8 @@ package arborist

import (
"encoding/json"
"github.com/jmoiron/sqlx"
"github.com/lib/pq"
)

type AuthRequestJSON struct {
Expand Down Expand Up @@ -76,3 +78,64 @@ func (requestJSON *AuthRequestJSON_Request) UnmarshalJSON(data []byte) error {

return nil
}

func authorize(db *sqlx.DB, token *TokenInfo, resource string, service string, method string) (bool, error) {
exp, args, err := Parse(resource)
if err != nil {
return false, err
}

stmt := `
SELECT coalesce(text2ltree("unnest") @> allowed, FALSE) FROM (
SELECT array_agg(resource.path) AS allowed
FROM usr
LEFT JOIN usr_policy
ON usr_policy.usr_id = usr.id
LEFT JOIN policy_resource
ON policy_resource.policy_id = usr_policy.policy_id
LEFT JOIN resource
ON resource.id = policy_resource.resource_id
WHERE usr.name = $1
AND EXISTS (
SELECT 1
FROM policy_role
LEFT JOIN permission
ON permission.role_id = policy_role.role_id
WHERE policy_role.policy_id = usr_policy.policy_id
AND permission.service = $2
AND permission.method = $3
)
AND ($4 OR usr_policy.policy_id IN (
SELECT id
FROM policy
WHERE policy.name = ANY($5)
))
) _, unnest($6::text[]);
`

rows, err := db.Query(stmt,
token.username, // $1
service, // $2
method, // $3
len(token.policies) == 0, // $4
pq.Array(token.policies), // $5
pq.Array(args), // $6
)
if err != nil {
return false, err
}

i := 0
vars := make(map[string]bool)
for rows.Next() {
var result bool
err = rows.Scan(&result)
if err != nil {
return false, err
}
vars[args[i]] = result
i ++
}

return Eval(exp, vars)
}
46 changes: 25 additions & 21 deletions arborist/resource_rules.go

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

46 changes: 25 additions & 21 deletions arborist/resource_rules.y
Original file line number Diff line number Diff line change
Expand Up @@ -75,7 +75,7 @@ expr
%%
type Lexer struct {
scanner.Scanner
Vars map[string]interface{}
args map[string]interface{}
result Expression
error string
}
Expand All @@ -94,6 +94,7 @@ func (l *Lexer) Lex(lval *yySymType) int {
default:
if lit != "" && lit != "(" && lit != ")" {
tok = VARIABLE
l.args[lit] = nil
}
}
lval.token = Token{token: tok, literal: lit}
Expand All @@ -104,26 +105,17 @@ func (l *Lexer) Error(e string) {
l.error = e
}

func (l *Lexer) Eval(e Expression) (bool, error) {
func Eval(e Expression, vars map[string]bool) (bool, error) {
switch t := e.(type) {
case Variable:
if v, ok := l.Vars[t.literal]; ok {
switch v.(type) {
case bool:
if v.(bool) {
return true, nil
} else {
return false, nil
}
default:
return false, errors.New("Parameter type must be boolean")
}
if v, ok := vars[t.literal]; ok {
return v, nil
}
return false, errors.New("Undefined symbol: " + t.literal)
case ParenExpr:
return l.Eval(t.SubExpr)
return Eval(t.SubExpr, vars)
case UnaryExpr:
right, err := l.Eval(t.right)
right, err := Eval(t.right, vars)
if err != nil {
return false, err
}
Expand All @@ -132,11 +124,11 @@ func (l *Lexer) Eval(e Expression) (bool, error) {
return ! right, nil
}
case AssocExpr:
left, err := l.Eval(t.left)
left, err := Eval(t.left, vars)
if err != nil {
return false, err
}
right, err := l.Eval(t.right)
right, err := Eval(t.right, vars)
if err != nil {
return false, err
}
Expand All @@ -150,14 +142,26 @@ func (l *Lexer) Eval(e Expression) (bool, error) {
return false, errors.New("Unexpected error")
}

func Parse(exp string, vars map[string]interface{}) (bool, error) {
func Parse(exp string) (Expression, []string, error) {
l := new(Lexer)
l.Vars = vars
l.args = make(map[string]interface{})
l.Init(strings.NewReader(exp))
if yyParse(l) != 0 {
return false, errors.New(l.error)
return nil, nil, errors.New(l.error)
}
return l.Eval(l.result)
args := make([]string, 0)
for arg := range l.args {
args = append(args, arg)
}
return l.result, args, nil
}

func Run(exp string, vars map[string]bool) (bool, error) {
e, _, err := Parse(exp)
if err != nil {
return false, err
}
return Eval(e, vars)
}

func init() {
Expand Down
6 changes: 3 additions & 3 deletions arborist/resource_rules_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,8 +5,8 @@ import (
"testing"
)

func ParseOrFail(t *testing.T, exp string, vars map[string]interface{}) bool {
rv, err := Parse(exp, vars)
func ParseOrFail(t *testing.T, exp string, vars map[string]bool) bool {
rv, err := Run(exp, vars)
if err != nil {
t.Error(err)
}
Expand All @@ -24,7 +24,7 @@ func assertEqual(t *testing.T, a interface{}, b interface{}, message string) {
}

func TestParse(t *testing.T) {
vars := map[string]interface{}{
vars := map[string]bool{
"T": true,
"F": false,
}
Expand Down
65 changes: 60 additions & 5 deletions arborist/server.go
Original file line number Diff line number Diff line change
Expand Up @@ -81,8 +81,8 @@ func (server *Server) MakeRouter(out io.Writer) http.Handler {

router.HandleFunc("/health", server.handleHealth).Methods("GET")

//router.Handle("/auth/proxy", http.HandlerFunc(server.handleAuthProxy)).Methods("POST")
//router.Handle("/auth/request", server.handleAuthRequest).Methods("POST")
router.Handle("/auth/proxy", http.HandlerFunc(server.handleAuthProxy)).Methods("GET")
router.Handle("/auth/request", http.HandlerFunc(parseJSON(server.handleAuthRequest))).Methods("POST")
//router.Handle("/auth/resources", server.handleListAuthResources).Methods("POST")

router.Handle("/policy", http.HandlerFunc(server.handlePolicyList)).Methods("GET")
Expand Down Expand Up @@ -133,7 +133,6 @@ func (server *Server) handleHealth(w http.ResponseWriter, r *http.Request) {
w.WriteHeader(http.StatusOK)
}

/*
func (server *Server) handleAuthProxy(w http.ResponseWriter, r *http.Request) {
// Get QS arguments
resourcePathQS, ok := r.URL.Query()["resource"]
Expand Down Expand Up @@ -183,9 +182,65 @@ func (server *Server) handleAuthProxy(w http.ResponseWriter, r *http.Request) {
return
}

// TODO
w.Header().Set("REMOTE_USER", info.username)

rv, err := authorize(server.db, info, resourcePath, service, method)
if err != nil {
msg := fmt.Sprintf("could not authorize: %s", err.Error())
server.logger.Info("tried to handle auth request but input was invalid: %s", msg)
response := newErrorResponse(msg, 400, nil)
_ = response.write(w, r)
return
}
if !rv {
errResponse := newErrorResponse(
"Unauthorized: user does not have access to this resource", 403, nil)
_ = errResponse.write(w, r)
}
}

func (server *Server) handleAuthRequest(w http.ResponseWriter, r *http.Request, body []byte) {
authRequest := &AuthRequest{}
err := json.Unmarshal(body, authRequest)
if err != nil {
msg := fmt.Sprintf("could not parse auth request from JSON: %s", err.Error())
server.logger.Info("tried to handle auth request but input was invalid: %s", msg)
response := newErrorResponse(msg, 400, nil)
_ = response.write(w, r)
return
}

var aud []string
if authRequest.User.Audiences == nil {
aud = []string{"openid"}
} else {
aud = make([]string, len(authRequest.User.Audiences))
copy(aud, authRequest.User.Audiences)
}

info, err := server.decodeToken(authRequest.User.Token, aud)
if err != nil {
server.logger.Info(err.Error())
errResponse := newErrorResponse(err.Error(), 401, &err)
_ = errResponse.write(w, r)
return
}

if authRequest.User.Policies != nil {
info.policies = authRequest.User.Policies
}

rv, err := authorize(server.db, info, authRequest.Request.Resource,
authRequest.Request.Action.Service, authRequest.Request.Action.Method)
if err != nil {
msg := fmt.Sprintf("could not authorize: %s", err.Error())
server.logger.Info("tried to handle auth request but input was invalid: %s", msg)
response := newErrorResponse(msg, 400, nil)
_ = response.write(w, r)
return
}
_ = jsonResponseFrom(AuthResponse{rv}, 200).write(w, r)
}
*/

func (server *Server) handlePolicyList(w http.ResponseWriter, r *http.Request) {
policies, err := listPoliciesFromDb(server.db)
Expand Down

0 comments on commit 639d9da

Please sign in to comment.