diff --git a/build/scripts/find-lint.sh b/build/scripts/find-lint.sh index af87e14d7f..e3564ae380 100755 --- a/build/scripts/find-lint.sh +++ b/build/scripts/find-lint.sh @@ -33,7 +33,7 @@ echo "Looking for lint..." # Capture exit code to ensure go.{mod,sum} is restored before exiting exit_code=0 -golangci-lint run $args || exit_code=1 +PATH="$PATH:${GOPATH:-~/go}/bin" golangci-lint run $args || exit_code=1 # Restore go.{mod,sum} mv go.mod.bak go.mod && mv go.sum.bak go.sum diff --git a/clientapi/auth/auth.go b/clientapi/auth/auth.go index b4c39ae38d..8fd8511845 100644 --- a/clientapi/auth/auth.go +++ b/clientapi/auth/auth.go @@ -42,6 +42,8 @@ type DeviceDatabase interface { type AccountDatabase interface { // Look up the account matching the given localpart. GetAccountByLocalpart(ctx context.Context, localpart string) (*api.Account, error) + GetAccountByPassword(ctx context.Context, localpart, password string) (*api.Account, error) + GetLocalpartForThreePID(ctx context.Context, address, medium string) (string, error) } // VerifyUserFromRequest authenticates the HTTP request, diff --git a/clientapi/auth/authtypes/logintypes.go b/clientapi/auth/authtypes/logintypes.go index da0324251e..2b19b355c9 100644 --- a/clientapi/auth/authtypes/logintypes.go +++ b/clientapi/auth/authtypes/logintypes.go @@ -10,4 +10,6 @@ const ( LoginTypeSharedSecret = "org.matrix.login.shared_secret" LoginTypeRecaptcha = "m.login.recaptcha" LoginTypeApplicationService = "m.login.application_service" + LoginTypeSSO = "m.login.sso" + LoginTypeToken = "m.login.token" ) diff --git a/clientapi/auth/login.go b/clientapi/auth/login.go new file mode 100644 index 0000000000..ec8bf3b2ff --- /dev/null +++ b/clientapi/auth/login.go @@ -0,0 +1,83 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "encoding/json" + "io" + "io/ioutil" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// LoginFromJSONReader performs authentication given a login request body reader and +// some context. It returns the basic login information and a cleanup function to be +// called after authorization has completed, with the result of the authorization. +// If the final return value is non-nil, an error occurred and the cleanup function +// is nil. +func LoginFromJSONReader(ctx context.Context, r io.Reader, accountDB AccountDatabase, userAPI UserInternalAPIForLogin, cfg *config.ClientAPI) (*Login, LoginCleanupFunc, *util.JSONResponse) { + reqBytes, err := ioutil.ReadAll(r) + if err != nil { + err := &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Reading request body failed: " + err.Error()), + } + return nil, nil, err + } + + var header struct { + Type string `json:"type"` + } + if err := json.Unmarshal(reqBytes, &header); err != nil { + err := &util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.BadJSON("Reading request body failed: " + err.Error()), + } + return nil, nil, err + } + + var typ Type + switch header.Type { + case authtypes.LoginTypePassword: + typ = &LoginTypePassword{ + GetAccountByPassword: accountDB.GetAccountByPassword, + Config: cfg, + } + case authtypes.LoginTypeToken: + typ = &LoginTypeToken{ + UserAPI: userAPI, + Config: cfg, + } + default: + err := util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("unhandled login type: " + header.Type), + } + return nil, nil, &err + } + + return typ.LoginFromJSON(ctx, reqBytes) +} + +// UserInternalAPIForLogin contains the aspects of UserAPI required for logging in. +type UserInternalAPIForLogin interface { + uapi.LoginTokenInternalAPI +} diff --git a/clientapi/auth/login_test.go b/clientapi/auth/login_test.go new file mode 100644 index 0000000000..d64d05ff94 --- /dev/null +++ b/clientapi/auth/login_test.go @@ -0,0 +1,131 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "reflect" + "strings" + "testing" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" +) + +func TestLoginFromJSONReader(t *testing.T) { + ctx := context.Background() + + tsts := []struct { + Name string + Body string + + WantErrCode string + WantUsername string + WantDeviceID string + WantDeletedTokens []string + }{ + {Name: "empty", WantErrCode: "M_BAD_JSON"}, + { + Name: "passwordWorks", + Body: `{ + "type": "m.login.password", + "identifier": { "type": "m.id.user", "user": "alice" }, + "password": "herpassword", + "device_id": "adevice" + }`, + WantUsername: "alice", + WantDeviceID: "adevice", + }, + { + Name: "tokenWorks", + Body: `{ + "type": "m.login.token", + "token": "atoken", + "device_id": "adevice" + }`, + WantUsername: "@auser:example.com", + WantDeviceID: "adevice", + WantDeletedTokens: []string{"atoken"}, + }, + } + for _, tst := range tsts { + t.Run(tst.Name, func(t *testing.T) { + var accountDB fakeAccountDB + var userAPI fakeUserInternalAPI + cfg := &config.ClientAPI{ + Matrix: &config.Global{ + ServerName: serverName, + }, + } + login, cleanup, errRes := LoginFromJSONReader(ctx, strings.NewReader(tst.Body), &accountDB, &userAPI, cfg) + if tst.WantErrCode == "" { + if errRes != nil { + t.Fatalf("LoginFromJSONReader failed: %+v", errRes) + } + cleanup(ctx, nil) + } else { + if errRes == nil { + t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) + } else if merr, ok := errRes.JSON.(*jsonerror.MatrixError); ok && merr.ErrCode != tst.WantErrCode { + t.Fatalf("LoginFromJSONReader err: got %+v, want code %q", errRes, tst.WantErrCode) + } + return + } + + if login.Username() != tst.WantUsername { + t.Errorf("Username: got %q, want %q", login.Username(), tst.WantUsername) + } + + if login.DeviceID == nil { + if tst.WantDeviceID != "" { + t.Errorf("DeviceID: got %v, want %q", login.DeviceID, tst.WantDeviceID) + } + } else { + if *login.DeviceID != tst.WantDeviceID { + t.Errorf("DeviceID: got %q, want %q", *login.DeviceID, tst.WantDeviceID) + } + } + + if !reflect.DeepEqual(userAPI.DeletedTokens, tst.WantDeletedTokens) { + t.Errorf("DeletedTokens: got %+v, want %+v", userAPI.DeletedTokens, tst.WantDeletedTokens) + } + }) + } +} + +type fakeAccountDB struct { + AccountDatabase +} + +func (*fakeAccountDB) GetAccountByPassword(ctx context.Context, localpart, password string) (*uapi.Account, error) { + return &uapi.Account{}, nil +} + +type fakeUserInternalAPI struct { + UserInternalAPIForLogin + + DeletedTokens []string +} + +func (ua *fakeUserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *uapi.PerformLoginTokenDeletionRequest, res *uapi.PerformLoginTokenDeletionResponse) error { + ua.DeletedTokens = append(ua.DeletedTokens, req.Token) + return nil +} + +func (*fakeUserInternalAPI) QueryLoginToken(ctx context.Context, req *uapi.QueryLoginTokenRequest, res *uapi.QueryLoginTokenResponse) error { + res.Data = &uapi.LoginTokenData{UserID: "@auser:example.com"} + return nil +} diff --git a/clientapi/auth/login_token.go b/clientapi/auth/login_token.go new file mode 100644 index 0000000000..f9c57a48f8 --- /dev/null +++ b/clientapi/auth/login_token.go @@ -0,0 +1,84 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package auth + +import ( + "context" + "net/http" + + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// LoginTypeToken describes how to authenticate with a login token. +type LoginTypeToken struct { + UserAPI uapi.LoginTokenInternalAPI + Config *config.ClientAPI +} + +// Name implements Type. +func (t *LoginTypeToken) Name() string { + return authtypes.LoginTypeToken +} + +// LoginFromJSON implements Type. The cleanup function deletes the token from +// the database on success. +func (t *LoginTypeToken) LoginFromJSON(ctx context.Context, reqBytes []byte) (*Login, LoginCleanupFunc, *util.JSONResponse) { + var r loginTokenRequest + if err := httputil.UnmarshalJSON(reqBytes, &r); err != nil { + return nil, nil, err + } + + return t.login(ctx, &r) +} + +// loginTokenRequest struct to hold the possible parameters from an HTTP request. +type loginTokenRequest struct { + Login + Token string `json:"token"` +} + +// login parses and validates the login token. It returns basic user information. +func (t *LoginTypeToken) login(ctx context.Context, r *loginTokenRequest) (*Login, LoginCleanupFunc, *util.JSONResponse) { + var res uapi.QueryLoginTokenResponse + if err := t.UserAPI.QueryLoginToken(ctx, &uapi.QueryLoginTokenRequest{Token: r.Token}, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("UserAPI.QueryLoginToken failed") + jsonErr := jsonerror.InternalServerError() + return nil, nil, &jsonErr + } + if res.Data == nil { + return nil, nil, &util.JSONResponse{ + Code: http.StatusForbidden, + JSON: jsonerror.Forbidden("invalid login token"), + } + } + + r.Login.Identifier.Type = "m.id.user" + r.Login.Identifier.User = res.Data.UserID + + cleanup := func(ctx context.Context, authRes *util.JSONResponse) { + if authRes == nil || authRes.Code == http.StatusOK { + var res uapi.PerformLoginTokenDeletionResponse + if err := t.UserAPI.PerformLoginTokenDeletion(ctx, &uapi.PerformLoginTokenDeletionRequest{Token: r.Token}, &res); err != nil { + util.GetLogger(ctx).WithError(err).Error("UserAPI.PerformLoginTokenDeletion failed") + } + } + } + return &r.Login, cleanup, nil +} diff --git a/clientapi/auth/password.go b/clientapi/auth/password.go index a66e2fe76e..ce1cb82f3e 100644 --- a/clientapi/auth/password.go +++ b/clientapi/auth/password.go @@ -18,6 +18,8 @@ import ( "context" "net/http" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/httputil" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" @@ -39,15 +41,24 @@ type LoginTypePassword struct { } func (t *LoginTypePassword) Name() string { - return "m.login.password" + return authtypes.LoginTypePassword } -func (t *LoginTypePassword) Request() interface{} { - return &PasswordRequest{} +func (t *LoginTypePassword) LoginFromJSON(ctx context.Context, reqBytes []byte) (*Login, LoginCleanupFunc, *util.JSONResponse) { + var r PasswordRequest + if err := httputil.UnmarshalJSON(reqBytes, &r); err != nil { + return nil, nil, err + } + + login, err := t.Login(ctx, &r) + if err != nil { + return nil, nil, err + } + + return login, func(context.Context, *util.JSONResponse) {}, nil } -func (t *LoginTypePassword) Login(ctx context.Context, req interface{}) (*Login, *util.JSONResponse) { - r := req.(*PasswordRequest) +func (t *LoginTypePassword) Login(ctx context.Context, r *PasswordRequest) (*Login, *util.JSONResponse) { username := r.Username() if username == "" { return nil, &util.JSONResponse{ diff --git a/clientapi/auth/sso/github.go b/clientapi/auth/sso/github.go new file mode 100644 index 0000000000..9ef5cd2daf --- /dev/null +++ b/clientapi/auth/sso/github.go @@ -0,0 +1,37 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sso + +import ( + "github.com/matrix-org/dendrite/setup/config" +) + +// GitHubIdentityProvider is a GitHub-flavored identity provider. +var GitHubIdentityProvider IdentityProvider = githubIdentityProvider{ + baseOIDCIdentityProvider: &baseOIDCIdentityProvider{ + AuthURL: mustParseURLTemplate("https://github.com/login/oauth/authorize?scope=user:email"), + AccessTokenURL: mustParseURLTemplate("https://github.com/login/oauth/access_token"), + UserInfoURL: mustParseURLTemplate("https://api.github.com/user"), + UserInfoAccept: "application/vnd.github.v3+json", + UserInfoEmailPath: "email", + UserInfoSuggestedUserIDPath: "login", + }, +} + +type githubIdentityProvider struct { + *baseOIDCIdentityProvider +} + +func (githubIdentityProvider) DefaultBrand() string { return config.SSOBrandGitHub } diff --git a/clientapi/auth/sso/oidc_base.go b/clientapi/auth/sso/oidc_base.go new file mode 100644 index 0000000000..edc600d4c4 --- /dev/null +++ b/clientapi/auth/sso/oidc_base.go @@ -0,0 +1,262 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sso + +import ( + "context" + "encoding/json" + "fmt" + "io/ioutil" + "mime" + "net/http" + "net/url" + "strings" + "text/template" + + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/tidwall/gjson" +) + +type baseOIDCIdentityProvider struct { + AuthURL *urlTemplate + AccessTokenURL *urlTemplate + UserInfoURL *urlTemplate + UserInfoAccept string + UserInfoEmailPath string + UserInfoSuggestedUserIDPath string +} + +func (p *baseOIDCIdentityProvider) AuthorizationURL(ctx context.Context, req *IdentityProviderRequest) (string, error) { + u, err := p.AuthURL.Execute(map[string]interface{}{ + "Config": req.System, + "State": req.DendriteNonce, + "RedirectURI": req.CallbackURL, + }, url.Values{ + "client_id": []string{req.System.OIDC.ClientID}, + "response_type": []string{"code"}, + "redirect_uri": []string{req.CallbackURL}, + "state": []string{req.DendriteNonce}, + }) + if err != nil { + return "", err + } + return u.String(), nil +} + +func (p *baseOIDCIdentityProvider) ProcessCallback(ctx context.Context, req *IdentityProviderRequest, values url.Values) (*CallbackResult, error) { + state := values.Get("state") + if state == "" { + return nil, jsonerror.MissingArgument("state parameter missing") + } + if state != req.DendriteNonce { + return nil, jsonerror.InvalidArgumentValue("state parameter not matching nonce") + } + + if error := values.Get("error"); error != "" { + if euri := values.Get("error_uri"); euri != "" { + return &CallbackResult{RedirectURL: euri}, nil + } + + desc := values.Get("error_description") + if desc == "" { + desc = error + } + switch error { + case "unauthorized_client", "access_denied": + return nil, jsonerror.Forbidden("SSO said no: " + desc) + default: + return nil, fmt.Errorf("SSO failed: %v", error) + } + } + + code := values.Get("code") + if code == "" { + return nil, jsonerror.MissingArgument("code parameter missing") + } + + oidcAccessToken, err := p.getOIDCAccessToken(ctx, req, code) + if err != nil { + return nil, err + } + + id, userID, err := p.getUserInfo(ctx, req, oidcAccessToken) + if err != nil { + return nil, err + } + + return &CallbackResult{Identifier: id, SuggestedUserID: userID}, nil +} + +func (p *baseOIDCIdentityProvider) getOIDCAccessToken(ctx context.Context, req *IdentityProviderRequest, code string) (string, error) { + u, err := p.AccessTokenURL.Execute(nil, nil) + if err != nil { + return "", err + } + + body := url.Values{ + "grant_type": []string{"authorization_code"}, + "code": []string{code}, + "redirect_uri": []string{req.CallbackURL}, + "client_id": []string{req.System.OIDC.ClientID}, + } + + hreq, err := http.NewRequestWithContext(ctx, http.MethodPost, u.String(), strings.NewReader(body.Encode())) + if err != nil { + return "", err + } + hreq.Header.Set("Content-Type", "application/x-www-form-urlencoded") + hreq.Header.Set("Accept", "application/x-www-form-urlencoded") + + hresp, err := http.DefaultClient.Do(hreq) + if err != nil { + return "", err + } + defer hresp.Body.Close() + + ctype, _, err := mime.ParseMediaType(hresp.Header.Get("Content-Type")) + if err != nil { + return "", err + } + if ctype != "application/json" { + return "", fmt.Errorf("expected URL encoded response, got content type %q", ctype) + } + + var resp struct { + TokenType string `json:"token_type"` + AccessToken string `json:"access_token"` + + Error string `json:"error"` + ErrorDescription string `json:"error_description"` + ErrorURI string `json:"error_uri"` + } + if err := json.NewDecoder(hresp.Body).Decode(&resp); err != nil { + return "", err + } + + if resp.Error != "" { + desc := resp.ErrorDescription + if desc == "" { + desc = resp.Error + } + return "", fmt.Errorf("failed to retrieve OIDC access token: %s", desc) + } + + if strings.ToLower(resp.TokenType) != "bearer" { + return "", fmt.Errorf("expected bearer token, got type %q", resp.TokenType) + } + + return resp.AccessToken, nil +} + +func (p *baseOIDCIdentityProvider) getUserInfo(ctx context.Context, req *IdentityProviderRequest, oidcAccessToken string) (userutil.Identifier, string, error) { + u, err := p.UserInfoURL.Execute(map[string]interface{}{ + "Config": req.System, + }, nil) + if err != nil { + return nil, "", err + } + + hreq, err := http.NewRequestWithContext(ctx, http.MethodGet, u.String(), nil) + if err != nil { + return nil, "", err + } + hreq.Header.Set("Authorization", "token "+oidcAccessToken) + hreq.Header.Set("Accept", p.UserInfoAccept) + + hresp, err := http.DefaultClient.Do(hreq) + if err != nil { + return nil, "", err + } + defer hresp.Body.Close() + + ctype, _, err := mime.ParseMediaType(hresp.Header.Get("Content-Type")) + if err != nil { + return nil, "", err + } + + var email string + var suggestedUserID string + switch ctype { + case "application/json": + body, err := ioutil.ReadAll(hresp.Body) + if err != nil { + return nil, "", err + } + + emailRes := gjson.GetBytes(body, p.UserInfoEmailPath) + if !emailRes.Exists() { + return nil, "", fmt.Errorf("no email in user info response body") + } + email = emailRes.String() + + // This is optional. + userIDRes := gjson.GetBytes(body, p.UserInfoSuggestedUserIDPath) + suggestedUserID = userIDRes.String() + + default: + return nil, "", fmt.Errorf("got unknown content type %q for user info", ctype) + } + + if email == "" { + return nil, "", fmt.Errorf("no email address in user info") + } + + return &userutil.ThirdPartyIdentifier{Medium: "email", Address: email}, suggestedUserID, nil +} + +type urlTemplate struct { + base *template.Template +} + +func parseURLTemplate(s string) (*urlTemplate, error) { + t, err := template.New("").Parse(s) + if err != nil { + return nil, err + } + return &urlTemplate{base: t}, nil +} + +func mustParseURLTemplate(s string) *urlTemplate { + t, err := parseURLTemplate(s) + if err != nil { + panic(err) + } + return t +} + +func (t *urlTemplate) Execute(params interface{}, defaultQuery url.Values) (*url.URL, error) { + var sb strings.Builder + err := t.base.Execute(&sb, params) + if err != nil { + return nil, err + } + + u, err := url.Parse(sb.String()) + if err != nil { + return nil, err + } + + if defaultQuery != nil { + q := u.Query() + for k, vs := range defaultQuery { + if q.Get(k) == "" { + q[k] = vs + } + } + u.RawQuery = q.Encode() + } + return u, nil +} diff --git a/clientapi/auth/sso/sso.go b/clientapi/auth/sso/sso.go new file mode 100644 index 0000000000..1b9215983d --- /dev/null +++ b/clientapi/auth/sso/sso.go @@ -0,0 +1,57 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sso + +import ( + "context" + "net/url" + + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/setup/config" +) + +type IdentityProvider interface { + DefaultBrand() string + + AuthorizationURL(context.Context, *IdentityProviderRequest) (string, error) + ProcessCallback(context.Context, *IdentityProviderRequest, url.Values) (*CallbackResult, error) +} + +type IdentityProviderRequest struct { + System *config.IdentityProvider + CallbackURL string + DendriteNonce string +} + +type CallbackResult struct { + RedirectURL string + Identifier userutil.Identifier + SuggestedUserID string +} + +type IdentityProviderType string + +const ( + TypeGitHub IdentityProviderType = config.SSOBrandGitHub +) + +func GetIdentityProvider(t IdentityProviderType) IdentityProvider { + switch t { + case TypeGitHub: + return GitHubIdentityProvider + default: + return nil + } +} diff --git a/clientapi/auth/user_interactive.go b/clientapi/auth/user_interactive.go index 30469fc474..9cab7956c6 100644 --- a/clientapi/auth/user_interactive.go +++ b/clientapi/auth/user_interactive.go @@ -32,22 +32,24 @@ import ( type Type interface { // Name returns the name of the auth type e.g `m.login.password` Name() string - // Request returns a pointer to a new request body struct to unmarshal into. - Request() interface{} // Login with the auth type, returning an error response on failure. // Not all types support login, only m.login.password and m.login.token // See https://matrix.org/docs/spec/client_server/r0.6.1#post-matrix-client-r0-login - // `req` is guaranteed to be the type returned from Request() // This function will be called when doing login and when doing 'sudo' style // actions e.g deleting devices. The response must be a 401 as per: // "If the homeserver decides that an attempt on a stage was unsuccessful, but the // client may make a second attempt, it returns the same HTTP status 401 response as above, // with the addition of the standard errcode and error fields describing the error." - Login(ctx context.Context, req interface{}) (login *Login, errRes *util.JSONResponse) + // + // The returned cleanup function must be non-nil on success, and will be called after + // authorization has been completed. Its argument is the final result of authorization. + LoginFromJSON(ctx context.Context, reqBytes []byte) (login *Login, cleanup LoginCleanupFunc, errRes *util.JSONResponse) // TODO: Extend to support Register() flow // Register(ctx context.Context, sessionID string, req interface{}) } +type LoginCleanupFunc func(context.Context, *util.JSONResponse) + // LoginIdentifier represents identifier types // https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types type LoginIdentifier struct { @@ -61,11 +63,8 @@ type LoginIdentifier struct { // Login represents the shared fields used in all forms of login/sudo endpoints. type Login struct { - Type string `json:"type"` - Identifier LoginIdentifier `json:"identifier"` - User string `json:"user"` // deprecated in favour of identifier - Medium string `json:"medium"` // deprecated in favour of identifier - Address string `json:"address"` // deprecated in favour of identifier + LoginIdentifier // Flat fields deprecated in favour of `identifier`. + Identifier LoginIdentifier `json:"identifier"` // Both DeviceID and InitialDisplayName can be omitted, or empty strings ("") // Thus a pointer is needed to differentiate between the two @@ -111,12 +110,11 @@ type UserInteractive struct { Sessions map[string][]string } -func NewUserInteractive(getAccByPass GetAccountByPassword, cfg *config.ClientAPI) *UserInteractive { +func NewUserInteractive(accountDB AccountDatabase, cfg *config.ClientAPI) *UserInteractive { typePassword := &LoginTypePassword{ - GetAccountByPassword: getAccByPass, + GetAccountByPassword: accountDB.GetAccountByPassword, Config: cfg, } - // TODO: Add SSO login return &UserInteractive{ Completed: []string{}, Flows: []userInteractiveFlow{ @@ -236,18 +234,13 @@ func (u *UserInteractive) Verify(ctx context.Context, bodyBytes []byte, device * } } - r := loginType.Request() - if err := json.Unmarshal([]byte(gjson.GetBytes(bodyBytes, "auth").Raw), r); err != nil { - return nil, &util.JSONResponse{ - Code: http.StatusBadRequest, - JSON: jsonerror.BadJSON("The request body could not be decoded into valid JSON. " + err.Error()), - } - } - login, resErr := loginType.Login(ctx, r) - if resErr == nil { - u.AddCompletedStage(sessionID, authType) - // TODO: Check if there's more stages to go and return an error - return login, nil + login, cleanup, resErr := loginType.LoginFromJSON(ctx, []byte(gjson.GetBytes(bodyBytes, "auth").Raw)) + if resErr != nil { + return nil, u.ResponseWithChallenge(sessionID, resErr.JSON) } - return nil, u.ResponseWithChallenge(sessionID, resErr.JSON) + + u.AddCompletedStage(sessionID, authType) + cleanup(ctx, nil) + // TODO: Check if there's more stages to go and return an error + return login, nil } diff --git a/clientapi/auth/user_interactive_test.go b/clientapi/auth/user_interactive_test.go index 0b7df35453..76d161a74f 100644 --- a/clientapi/auth/user_interactive_test.go +++ b/clientapi/auth/user_interactive_test.go @@ -24,7 +24,11 @@ var ( } ) -func getAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) { +type fakeAccountDatabase struct { + AccountDatabase +} + +func (*fakeAccountDatabase) GetAccountByPassword(ctx context.Context, localpart, plaintextPassword string) (*api.Account, error) { acc, ok := lookup[localpart+" "+plaintextPassword] if !ok { return nil, fmt.Errorf("unknown user/password") @@ -38,7 +42,7 @@ func setup() *UserInteractive { ServerName: serverName, }, } - return NewUserInteractive(getAccountByPassword, cfg) + return NewUserInteractive(&fakeAccountDatabase{}, cfg) } func TestUserInteractiveChallenge(t *testing.T) { diff --git a/clientapi/httputil/httputil.go b/clientapi/httputil/httputil.go index 29d7b0b37f..b47701368b 100644 --- a/clientapi/httputil/httputil.go +++ b/clientapi/httputil/httputil.go @@ -36,6 +36,10 @@ func UnmarshalJSONRequest(req *http.Request, iface interface{}) *util.JSONRespon return &resp } + return UnmarshalJSON(body, iface) +} + +func UnmarshalJSON(body []byte, iface interface{}) *util.JSONResponse { if !utf8.Valid(body) { return &util.JSONResponse{ Code: http.StatusBadRequest, diff --git a/clientapi/routing/login.go b/clientapi/routing/login.go index 589efe0b2f..f7277cd26f 100644 --- a/clientapi/routing/login.go +++ b/clientapi/routing/login.go @@ -19,11 +19,12 @@ import ( "net/http" "github.com/matrix-org/dendrite/clientapi/auth" - "github.com/matrix-org/dendrite/clientapi/httputil" + "github.com/matrix-org/dendrite/clientapi/auth/authtypes" + "github.com/matrix-org/dendrite/clientapi/auth/sso" "github.com/matrix-org/dendrite/clientapi/jsonerror" "github.com/matrix-org/dendrite/clientapi/userutil" "github.com/matrix-org/dendrite/setup/config" - userapi "github.com/matrix-org/dendrite/userapi/api" + uapi "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/storage/accounts" "github.com/matrix-org/gomatrixserverlib" "github.com/matrix-org/util" @@ -37,50 +38,81 @@ type loginResponse struct { } type flows struct { - Flows []flow `json:"flows"` + Flows []stage `json:"flows"` } -type flow struct { - Type string `json:"type"` +type stage struct { + Type string `json:"type"` + IdentityProviders []identityProvider `json:"identity_providers,omitempty"` } -func passwordLogin() flows { - f := flows{} - s := flow{ - Type: "m.login.password", +type identityProvider struct { + ID string `json:"id"` + Name string `json:"name"` + Brand string `json:"brand,omitempty"` + Icon string `json:"icon,omitempty"` +} + +func passwordLogin() []stage { + return []stage{ + {Type: authtypes.LoginTypePassword}, + } +} + +func ssoLogin(cfg *config.ClientAPI) []stage { + var idps []identityProvider + for _, idp := range cfg.Login.SSO.Providers { + brand := idp.Brand + if brand == "" { + typ := idp.Type + if typ == "" { + typ = idp.ID + } + idpType := sso.GetIdentityProvider(sso.IdentityProviderType(typ)) + if idpType != nil { + brand = idpType.DefaultBrand() + } + } + idps = append(idps, identityProvider{ + ID: idp.ID, + Name: idp.Name, + Brand: brand, + Icon: idp.Icon, + }) + } + return []stage{ + { + Type: authtypes.LoginTypeSSO, + IdentityProviders: idps, + }, } - f.Flows = append(f.Flows, s) - return f } // Login implements GET and POST /login func Login( - req *http.Request, accountDB accounts.Database, userAPI userapi.UserInternalAPI, + req *http.Request, accountDB accounts.Database, userAPI uapi.UserInternalAPI, cfg *config.ClientAPI, ) util.JSONResponse { if req.Method == http.MethodGet { - // TODO: support other forms of login other than password, depending on config options + allFlows := passwordLogin() + if cfg.Login.SSO.Enabled { + allFlows = append(allFlows, ssoLogin(cfg)...) + } return util.JSONResponse{ Code: http.StatusOK, - JSON: passwordLogin(), + JSON: flows{Flows: allFlows}, } } else if req.Method == http.MethodPost { - typePassword := auth.LoginTypePassword{ - GetAccountByPassword: accountDB.GetAccountByPassword, - Config: cfg, - } - r := typePassword.Request() - resErr := httputil.UnmarshalJSONRequest(req, r) - if resErr != nil { - return *resErr - } - login, authErr := typePassword.Login(req.Context(), r) + login, cleanup, authErr := auth.LoginFromJSONReader(req.Context(), req.Body, accountDB, userAPI, cfg) if authErr != nil { return *authErr } // make a device/access token - return completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) + authzErr := completeAuth(req.Context(), cfg.Matrix.ServerName, userAPI, login, req.RemoteAddr, req.UserAgent()) + cleanup(req.Context(), &authzErr) + return authzErr } + return util.JSONResponse{ Code: http.StatusMethodNotAllowed, JSON: jsonerror.NotFound("Bad method"), @@ -88,7 +120,7 @@ func Login( } func completeAuth( - ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI userapi.UserInternalAPI, login *auth.Login, + ctx context.Context, serverName gomatrixserverlib.ServerName, userAPI uapi.UserInternalAPI, login *auth.Login, ipAddr, userAgent string, ) util.JSONResponse { token, err := auth.GenerateAccessToken() @@ -103,8 +135,8 @@ func completeAuth( return jsonerror.InternalServerError() } - var performRes userapi.PerformDeviceCreationResponse - err = userAPI.PerformDeviceCreation(ctx, &userapi.PerformDeviceCreationRequest{ + var performRes uapi.PerformDeviceCreationResponse + err = userAPI.PerformDeviceCreation(ctx, &uapi.PerformDeviceCreationRequest{ DeviceDisplayName: login.InitialDisplayName, DeviceID: login.DeviceID, AccessToken: token, diff --git a/clientapi/routing/routing.go b/clientapi/routing/routing.go index e5e9c7bd8f..90f60f98b9 100644 --- a/clientapi/routing/routing.go +++ b/clientapi/routing/routing.go @@ -64,7 +64,7 @@ func Setup( mscCfg *config.MSCs, ) { rateLimits := newRateLimits(&cfg.RateLimiting) - userInteractiveAuth := auth.NewUserInteractive(accountDB.GetAccountByPassword, cfg) + userInteractiveAuth := auth.NewUserInteractive(accountDB, cfg) unstableFeatures := map[string]bool{ "org.matrix.e2e_cross_signing": true, @@ -504,6 +504,25 @@ func Setup( }), ).Methods(http.MethodGet, http.MethodPost, http.MethodOptions) + r0mux.Handle("/login/sso/callback", + httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { + return SSOCallback(req, accountDB, userAPI, cfg) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + r0mux.Handle("/login/sso/redirect", + httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { + return SSORedirect(req, "", cfg) + }), + ).Methods(http.MethodGet, http.MethodOptions) + + r0mux.Handle("/login/sso/redirect/{idpID}", + httputil.MakeExternalAPI("login", func(req *http.Request) util.JSONResponse { + vars := mux.Vars(req) + return SSORedirect(req, vars["idpID"], cfg) + }), + ).Methods(http.MethodGet, http.MethodOptions) + r0mux.Handle("/auth/{authType}/fallback/web", httputil.MakeHTMLAPI("auth_fallback", func(w http.ResponseWriter, req *http.Request) *util.JSONResponse { vars := mux.Vars(req) diff --git a/clientapi/routing/sso.go b/clientapi/routing/sso.go new file mode 100644 index 0000000000..178c116e27 --- /dev/null +++ b/clientapi/routing/sso.go @@ -0,0 +1,259 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package routing + +import ( + "context" + "encoding/base64" + "fmt" + "net/http" + "net/url" + "strings" + "time" + + "github.com/matrix-org/dendrite/clientapi/auth" + "github.com/matrix-org/dendrite/clientapi/auth/sso" + "github.com/matrix-org/dendrite/clientapi/jsonerror" + "github.com/matrix-org/dendrite/clientapi/userutil" + "github.com/matrix-org/dendrite/setup/config" + uapi "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// SSORedirect implements /login/sso/redirect +// https://matrix.org/docs/spec/client_server/r0.6.1#get-matrix-client-r0-login-sso-redirect +func SSORedirect( + req *http.Request, + idpID string, + cfg *config.ClientAPI, +) util.JSONResponse { + if !cfg.Login.SSO.Enabled { + return util.JSONResponse{ + Code: http.StatusNotImplemented, + JSON: jsonerror.NotFound("authentication method disabled"), + } + } + + redirectURL := req.URL.Query().Get("redirectUrl") + if redirectURL == "" { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("redirectUrl parameter missing"), + } + } + _, err := url.Parse(redirectURL) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("Invalid redirectURL: " + err.Error()), + } + } + + if idpID == "" { + // Check configuration if the client didn't provide an ID. + idpID = cfg.Login.SSO.DefaultProviderID + } + if idpID == "" && len(cfg.Login.SSO.Providers) > 0 { + // Fall back to the first provider. If there are no providers, getProvider("") will fail. + idpID = cfg.Login.SSO.Providers[0].ID + } + idpCfg, idpType := getProvider(cfg, idpID) + if idpType == nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("unknown identity provider"), + } + } + + idpReq := &sso.IdentityProviderRequest{ + System: idpCfg, + CallbackURL: req.URL.ResolveReference(&url.URL{Path: "../callback", RawQuery: url.Values{"provider": []string{idpID}}.Encode()}).String(), + DendriteNonce: formatNonce(redirectURL), + } + u, err := idpType.AuthorizationURL(req.Context(), idpReq) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: err, + } + } + + resp := util.RedirectResponse(u) + resp.Headers["Set-Cookie"] = (&http.Cookie{ + Name: "oidc_nonce", + Value: idpReq.DendriteNonce, + Expires: time.Now().Add(10 * time.Minute), + Secure: true, + SameSite: http.SameSiteStrictMode, + }).String() + return resp +} + +func SSOCallback( + req *http.Request, + accountDB auth.AccountDatabase, + userAPI auth.UserInternalAPIForLogin, + cfg *config.ClientAPI, +) util.JSONResponse { + ctx := req.Context() + + query := req.URL.Query() + idpID := query.Get("provider") + if idpID == "" { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("provider parameter missing"), + } + } + idpCfg, idpType := getProvider(cfg, idpID) + if idpType == nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.InvalidArgumentValue("unknown identity provider"), + } + } + + nonce, err := req.Cookie("oidc_nonce") + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: jsonerror.MissingArgument("no nonce cookie: " + err.Error()), + } + } + finalRedirectURL, err := parseNonce(nonce.Value) + if err != nil { + return util.JSONResponse{ + Code: http.StatusBadRequest, + JSON: err, + } + } + + idpReq := &sso.IdentityProviderRequest{ + System: idpCfg, + CallbackURL: (&url.URL{Scheme: req.URL.Scheme, Host: req.URL.Host, Path: req.URL.Path, RawQuery: url.Values{"provider": []string{idpID}}.Encode()}).String(), + DendriteNonce: nonce.Value, + } + result, err := idpType.ProcessCallback(ctx, idpReq, query) + if err != nil { + return util.JSONResponse{ + Code: http.StatusInternalServerError, + JSON: err, + } + } + + if result.Identifier == nil { + // Not authenticated yet. + return util.RedirectResponse(result.RedirectURL) + } + + id, err := verifyUserIdentifier(ctx, accountDB, result.Identifier) + if err != nil { + util.GetLogger(ctx).WithError(err).WithField("identifier", result.Identifier.String()).Error("failed to find user") + return util.JSONResponse{ + Code: http.StatusUnauthorized, + JSON: jsonerror.Forbidden("ID not associated with a local account"), + } + } + + token, err := createLoginToken(ctx, userAPI, id) + if err != nil { + util.GetLogger(ctx).WithError(err).Errorf("PerformLoginTokenCreation failed") + return jsonerror.InternalServerError() + } + + rquery := finalRedirectURL.Query() + rquery.Set("loginToken", token.Token) + resp := util.RedirectResponse(finalRedirectURL.ResolveReference(&url.URL{RawQuery: rquery.Encode()}).String()) + resp.Headers["Set-Cookie"] = (&http.Cookie{ + Name: "oidc_nonce", + Value: "", + MaxAge: -1, + Secure: true, + }).String() + return resp +} + +func getProvider(cfg *config.ClientAPI, id string) (*config.IdentityProvider, sso.IdentityProvider) { + for _, idp := range cfg.Login.SSO.Providers { + if idp.ID == id { + switch sso.IdentityProviderType(id) { + case sso.TypeGitHub: + return &idp, sso.GitHubIdentityProvider + default: + return nil, nil + } + } + } + return nil, nil +} + +func formatNonce(redirectURL string) string { + return util.RandomString(16) + "." + base64.RawURLEncoding.EncodeToString([]byte(redirectURL)) +} + +func parseNonce(s string) (redirectURL *url.URL, _ error) { + if s == "" { + return nil, jsonerror.MissingArgument("empty OIDC nonce cookie") + } + + ss := strings.Split(s, ".") + if len(ss) < 2 { + return nil, jsonerror.InvalidArgumentValue("malformed OIDC nonce cookie") + } + + urlbs, err := base64.RawURLEncoding.DecodeString(ss[1]) + if err != nil { + return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in OIDC nonce cookie") + } + u, err := url.Parse(string(urlbs)) + if err != nil { + return nil, jsonerror.InvalidArgumentValue("invalid redirect URL in OIDC nonce cookie: " + err.Error()) + } + + return u, nil +} + +func verifyUserIdentifier(ctx context.Context, accountDB auth.AccountDatabase, id userutil.Identifier) (*userutil.UserIdentifier, error) { + var localpart string + switch iid := id.(type) { + case *userutil.ThirdPartyIdentifier: + var err error + localpart, err = accountDB.GetLocalpartForThreePID(ctx, iid.Address, string(iid.Medium)) + if err != nil { + return nil, err + } + + case *userutil.UserIdentifier: + localpart = iid.UserID + + default: + return nil, fmt.Errorf("unsupported ID type: %T", id) + } + + acc, err := accountDB.GetAccountByLocalpart(ctx, localpart) + if err != nil { + return nil, err + } + return &userutil.UserIdentifier{UserID: acc.UserID}, nil +} + +func createLoginToken(ctx context.Context, userAPI auth.UserInternalAPIForLogin, id *userutil.UserIdentifier) (*uapi.LoginTokenMetadata, error) { + req := uapi.PerformLoginTokenCreationRequest{Data: uapi.LoginTokenData{UserID: id.UserID}} + var resp uapi.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &req, &resp); err != nil { + return nil, err + } + return &resp.Metadata, nil +} diff --git a/clientapi/userutil/identifier.go b/clientapi/userutil/identifier.go new file mode 100644 index 0000000000..46c8a0f020 --- /dev/null +++ b/clientapi/userutil/identifier.go @@ -0,0 +1,153 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package userutil + +import ( + "bytes" + "encoding/json" + "errors" +) + +// An Identifier identifies a user. There are many kinds, and this is +// the common interface for them. +// +// If you need to handle an identifier as JSON, use the AnyIdentifier wrapper. +// Passing around identifiers in code, the raw Identifier is enough. +// +// See https://matrix.org/docs/spec/client_server/r0.6.1#identifier-types +type Identifier interface { + // IdentifierType returns the identifier type, like "m.id.user". + IdentifierType() IdentifierType + + // String returns a debug-output string representation. The format + // is not specified. + String() string +} + +// A UserIdentifier contains an MXID. It may be only the local part. +type UserIdentifier struct { + UserID string `json:"user"` +} + +func (i *UserIdentifier) IdentifierType() IdentifierType { return IdentifierUser } +func (i *UserIdentifier) String() string { return i.UserID } + +// A ThirdPartyIdentifier references an identifier in another system. +type ThirdPartyIdentifier struct { + // Medium is normally MediumEmail. + Medium Medium `json:"medium"` + + // Address is the medium-specific identifier. + Address string `json:"address"` +} + +func (i *ThirdPartyIdentifier) IdentifierType() IdentifierType { return IdentifierThirdParty } +func (i *ThirdPartyIdentifier) String() string { return string(i.Medium) + ":" + i.Address } + +// A PhoneIdentifier references a phone number. +type PhoneIdentifier struct { + // Country is a ISO-3166-1 alpha-2 country code. + Country string `json:"country"` + + // PhoneNumber is a country-specific phone number, as it would be dialled from. + PhoneNumber string `json:"phone"` +} + +func (i *PhoneIdentifier) IdentifierType() IdentifierType { return IdentifierPhone } +func (i *PhoneIdentifier) String() string { return i.Country + ":" + i.PhoneNumber } + +// UnknownIdentifier is the catch-all for identifiers this code doesn't know about. +// It simply stores raw JSON. +type UnknownIdentifier struct { + json.RawMessage + Type IdentifierType +} + +func (i *UnknownIdentifier) IdentifierType() IdentifierType { return i.Type } +func (i *UnknownIdentifier) String() string { return "unknown/" + string(i.Type) } + +// AnyIdentifier is a wrapper that allows marshalling and unmarshalling the various +// types of identifiers to/from JSON. Always use this in data types that will be +// used in JSON manipulation. +type AnyIdentifier struct { + Identifier +} + +func (i AnyIdentifier) MarshalJSON() ([]byte, error) { + v := struct { + *UserIdentifier + *ThirdPartyIdentifier + *PhoneIdentifier + Type IdentifierType `json:"type"` + }{ + Type: i.Identifier.IdentifierType(), + } + switch iid := i.Identifier.(type) { + case *UserIdentifier: + v.UserIdentifier = iid + case *ThirdPartyIdentifier: + v.ThirdPartyIdentifier = iid + case *PhoneIdentifier: + v.PhoneIdentifier = iid + case *UnknownIdentifier: + return iid.RawMessage, nil + } + return json.Marshal(v) +} + +func (i *AnyIdentifier) UnmarshalJSON(bs []byte) error { + var hdr struct { + Type IdentifierType `json:"type"` + } + if err := json.Unmarshal(bs, &hdr); err != nil { + return err + } + switch hdr.Type { + case IdentifierUser: + var ui UserIdentifier + i.Identifier = &ui + return json.Unmarshal(bs, &ui) + case IdentifierThirdParty: + var tpi ThirdPartyIdentifier + i.Identifier = &tpi + return json.Unmarshal(bs, &tpi) + case IdentifierPhone: + var pi PhoneIdentifier + i.Identifier = &pi + return json.Unmarshal(bs, &pi) + case "": + return errors.New("missing identifier type") + default: + i.Identifier = &UnknownIdentifier{RawMessage: json.RawMessage(bytes.TrimSpace(bs)), Type: hdr.Type} + return nil + } +} + +// IdentifierType describes the type of identifier. +type IdentifierType string + +const ( + IdentifierUser IdentifierType = "m.id.user" + IdentifierThirdParty IdentifierType = "m.id.thirdparty" + IdentifierPhone IdentifierType = "m.id.phone" +) + +// Medium describes the interpretation of a third-party identifier. +type Medium string + +const ( + // MediumEmail signifies that the address is an email address. + MediumEmail Medium = "email" +) diff --git a/clientapi/userutil/identifier_test.go b/clientapi/userutil/identifier_test.go new file mode 100644 index 0000000000..cd02524c3b --- /dev/null +++ b/clientapi/userutil/identifier_test.go @@ -0,0 +1,61 @@ +package userutil + +import ( + "encoding/json" + "reflect" + "testing" +) + +func TestAnyIdentifierJSON(t *testing.T) { + tsts := []struct { + Name string + JSON string + Want Identifier + }{ + {Name: "empty", JSON: `{}`}, + {Name: "user", JSON: `{"type":"m.id.user","user":"auser"}`, Want: &UserIdentifier{UserID: "auser"}}, + {Name: "thirdparty", JSON: `{"type":"m.id.thirdparty","medium":"email","address":"auser@example.com"}`, Want: &ThirdPartyIdentifier{Medium: "email", Address: "auser@example.com"}}, + {Name: "phone", JSON: `{"type":"m.id.phone","country":"GB","phone":"123456789"}`, Want: &PhoneIdentifier{Country: "GB", PhoneNumber: "123456789"}}, + // This test is a little fragile since it compares the output of json.Marshal. + {Name: "unknown", JSON: `{"type":"other"}`, Want: &UnknownIdentifier{Type: "other", RawMessage: json.RawMessage(`{"type":"other"}`)}}, + } + for _, tst := range tsts { + t.Run("Unmarshal/"+tst.Name, func(t *testing.T) { + var got AnyIdentifier + if err := json.Unmarshal([]byte(tst.JSON), &got); err != nil { + if tst.Want == nil { + return + } + t.Fatalf("Unmarshal failed: %v", err) + } + + if !reflect.DeepEqual(got.Identifier, tst.Want) { + t.Errorf("got %+v, want %+v", got.Identifier, tst.Want) + } + }) + + if tst.Want == nil { + continue + } + t.Run("Marshal/"+tst.Name, func(t *testing.T) { + id := AnyIdentifier{Identifier: tst.Want} + bs, err := json.Marshal(id) + if err != nil { + t.Fatalf("Marshal failed: %v", err) + } + t.Logf("Marshalled JSON: %q", string(bs)) + + var got AnyIdentifier + if err := json.Unmarshal(bs, &got); err != nil { + if tst.Want == nil { + return + } + t.Fatalf("Unmarshal failed: %v", err) + } + + if !reflect.DeepEqual(got.Identifier, tst.Want) { + t.Errorf("got %+v, want %+v", got.Identifier, tst.Want) + } + }) + } +} diff --git a/go.mod b/go.mod index e48d45361c..6ffc811005 100644 --- a/go.mod +++ b/go.mod @@ -33,7 +33,7 @@ require ( github.com/matrix-org/gomatrix v0.0.0-20210324163249-be2af5ef2e16 github.com/matrix-org/gomatrixserverlib v0.0.0-20210817115641-f9416ac1a723 github.com/matrix-org/naffka v0.0.0-20210623111924-14ff508b58e0 - github.com/matrix-org/pinecone v0.0.0-20210910134625-4ec11c22f2c8 + github.com/matrix-org/pinecone v0.0.0-20210920152116-4f07afaed998 github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 github.com/matryer/is v1.4.0 github.com/mattn/go-sqlite3 v1.14.8 diff --git a/go.sum b/go.sum index 7a4c082336..b5f07cfcb9 100644 --- a/go.sum +++ b/go.sum @@ -998,8 +998,8 @@ github.com/matrix-org/gomatrixserverlib v0.0.0-20210817115641-f9416ac1a723 h1:b8 github.com/matrix-org/gomatrixserverlib v0.0.0-20210817115641-f9416ac1a723/go.mod h1:JsAzE1Ll3+gDWS9JSUHPJiiyAksvOOnGWF2nXdg4ZzU= github.com/matrix-org/naffka v0.0.0-20210623111924-14ff508b58e0 h1:HZCzy4oVzz55e+cOMiX/JtSF2UOY1evBl2raaE7ACcU= github.com/matrix-org/naffka v0.0.0-20210623111924-14ff508b58e0/go.mod h1:sjyPyRxKM5uw1nD2cJ6O2OxI6GOqyVBfNXqKjBZTBZE= -github.com/matrix-org/pinecone v0.0.0-20210910134625-4ec11c22f2c8 h1:oE7rDEoz3J3PVkugf10XxjPnGO9PGZOSMwNazfI/0dM= -github.com/matrix-org/pinecone v0.0.0-20210910134625-4ec11c22f2c8/go.mod h1:CVlrvs1R5iz7Omy2GqAjJJKbACn07GZgUq1Gli18FYE= +github.com/matrix-org/pinecone v0.0.0-20210920152116-4f07afaed998 h1:Jdxog8Gh1/8IOhkeEub4LeFhaYWNJmOpG16F+eas6r4= +github.com/matrix-org/pinecone v0.0.0-20210920152116-4f07afaed998/go.mod h1:CVlrvs1R5iz7Omy2GqAjJJKbACn07GZgUq1Gli18FYE= github.com/matrix-org/util v0.0.0-20190711121626-527ce5ddefc7/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4 h1:eCEHXWDv9Rm335MSuB49mFUK44bwZPFSDde3ORE3syk= github.com/matrix-org/util v0.0.0-20200807132607-55161520e1d4/go.mod h1:vVQlW/emklohkZnOPwD3LrZUBqdfsbiyO3p1lNV8F6U= diff --git a/setup/base.go b/setup/base.go index 6901205c3d..82004ee4d6 100644 --- a/setup/base.go +++ b/setup/base.go @@ -91,10 +91,13 @@ type BaseDendrite struct { // KafkaProducer sarama.SyncProducer } -const HTTPServerTimeout = time.Minute * 5 -const HTTPClientTimeout = time.Second * 30 +const ( + HTTPServerTimeout = time.Minute * 5 + HTTPServerRequestTimeout = HTTPServerTimeout + HTTPClientTimeout = time.Second * 30 -const NoListener = "" + NoListener = "" +) // NewBaseDendrite creates a new instance to be used by a component. // The componentName is used for logging purposes, and should be a friendly name @@ -355,6 +358,7 @@ func (b *BaseDendrite) SetupAndServeHTTP( externalAddr, _ := externalHTTPAddr.Address() externalRouter := mux.NewRouter().SkipClean(true).UseEncodedPath() + externalRouter.Use(timeoutMiddleware) internalRouter := externalRouter externalServ := &http.Server{ @@ -475,6 +479,15 @@ func (b *BaseDendrite) SetupAndServeHTTP( logrus.Infof("Stopped HTTP listeners") } +// timeoutMiddleware is a Gorilla middleware that adds a timeout to all request contexts. +func timeoutMiddleware(next http.Handler) http.Handler { + return http.HandlerFunc(func(w http.ResponseWriter, r *http.Request) { + ctx, cancel := context.WithTimeout(r.Context(), HTTPServerRequestTimeout) + defer cancel() + next.ServeHTTP(w, r.WithContext(ctx)) + }) +} + func (b *BaseDendrite) WaitForShutdown() { sigs := make(chan os.Signal, 1) signal.Notify(sigs, syscall.SIGINT, syscall.SIGTERM) diff --git a/setup/config/config_clientapi.go b/setup/config/config_clientapi.go index c7cb9c33e0..cad284e20f 100644 --- a/setup/config/config_clientapi.go +++ b/setup/config/config_clientapi.go @@ -32,6 +32,8 @@ type ClientAPI struct { // was successful RecaptchaSiteVerifyAPI string `yaml:"recaptcha_siteverify_api"` + Login Login `yaml:"login"` + // TURN options TURN TURN `yaml:"turn"` @@ -53,6 +55,7 @@ func (c *ClientAPI) Defaults() { c.RecaptchaSiteVerifyAPI = "" c.RegistrationDisabled = false c.RateLimiting.Defaults() + c.Login.SSO.Enabled = false } func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { @@ -66,10 +69,130 @@ func (c *ClientAPI) Verify(configErrs *ConfigErrors, isMonolith bool) { checkNotEmpty(configErrs, "client_api.recaptcha_private_key", string(c.RecaptchaPrivateKey)) checkNotEmpty(configErrs, "client_api.recaptcha_siteverify_api", string(c.RecaptchaSiteVerifyAPI)) } + c.Login.Verify(configErrs) c.TURN.Verify(configErrs) c.RateLimiting.Verify(configErrs) } +type Login struct { + SSO SSO `yaml:"sso"` +} + +func (l *Login) Verify(configErrs *ConfigErrors) { + l.SSO.Verify(configErrs) +} + +type SSO struct { + // Enabled determines whether SSO should be allowed. + Enabled bool `yaml:"enabled"` + + // Providers list the identity providers this server is capable of confirming an + // identity with. + Providers []IdentityProvider `yaml:"providers"` + + // DefaultProviderID is the provider to use when the client doesn't indicate one. + // This is legacy support. If empty, the first provider listed is used. + DefaultProviderID string `yaml:"default_provider"` +} + +func (sso *SSO) Verify(configErrs *ConfigErrors) { + var foundDefaultProvider bool + seenPIDs := make(map[string]bool, len(sso.Providers)) + for _, p := range sso.Providers { + p.Verify(configErrs) + if p.ID == sso.DefaultProviderID { + foundDefaultProvider = true + } + if seenPIDs[p.ID] { + configErrs.Add(fmt.Sprintf("duplicate identity provider for config key %q: %s", "client_api.sso.providers", p.ID)) + } + seenPIDs[p.ID] = true + } + if sso.DefaultProviderID != "" && !foundDefaultProvider { + configErrs.Add(fmt.Sprintf("identity provider ID not found for config key %q: %s", "client_api.sso.default_provider", sso.DefaultProviderID)) + } + + if sso.Enabled { + if len(sso.Providers) == 0 { + configErrs.Add(fmt.Sprintf("empty list for config key %q", "client_api.sso.providers")) + } + } +} + +// See https://github.com/matrix-org/matrix-doc/blob/old_master/informal/idp-brands.md. +type IdentityProvider struct { + // ID is the unique identifier of this IdP. We use the brand identifiers as provider + // identifiers for simplicity. + ID string `yaml:"id"` + + // Name is a human-friendly name of the provider. + Name string `yaml:"name"` + + // Brand is a hint on how to display the IdP to the user. If this is empty, a default + // based on the type is used. + Brand string `yaml:"brand"` + + // Icon is an MXC URI describing how to display the IdP to the user. Prefer using `brand`. + Icon string `yaml:"icon"` + + // Type describes how this provider is implemented. It must match "github". If this is + // empty, the ID is used, which means there is a weak expectation that ID is also a + // valid type, unless you have a complicated setup. + Type string `yaml:"type"` + + // OIDC contains settings for providers based on OpenID Connect (OAuth 2). + OIDC struct { + ClientID string `yaml:"client_id"` + ClientSecret string `yaml:"client_secret"` + } `yaml:"oidc"` +} + +func (idp *IdentityProvider) Verify(configErrs *ConfigErrors) { + checkNotEmpty(configErrs, "client_api.sso.providers.id", idp.ID) + if !checkIdentityProviderBrand(idp.ID) { + configErrs.Add(fmt.Sprintf("unrecognized ID config key %q: %s", "client_api.sso.providers", idp.ID)) + } + checkNotEmpty(configErrs, "client_api.sso.providers.name", idp.Name) + if idp.Brand != "" && !checkIdentityProviderBrand(idp.Brand) { + configErrs.Add(fmt.Sprintf("unrecognized brand in identity provider %q for config key %q: %s", idp.ID, "client_api.sso.providers", idp.Brand)) + } + if idp.Icon != "" { + checkURL(configErrs, "client_api.sso.providers.icon", idp.Icon) + } + typ := idp.Type + if idp.Type == "" { + typ = idp.ID + } + + switch typ { + case "github": + checkNotEmpty(configErrs, "client_api.sso.providers.oidc.client_id", idp.OIDC.ClientID) + checkNotEmpty(configErrs, "client_api.sso.providers.oidc.client_secret", idp.OIDC.ClientSecret) + + default: + configErrs.Add(fmt.Sprintf("unrecognized type in identity provider %q for config key %q: %s", idp.ID, "client_api.sso.providers", typ)) + } +} + +// See https://github.com/matrix-org/matrix-doc/blob/old_master/informal/idp-brands.md. +func checkIdentityProviderBrand(s string) bool { + switch s { + case SSOBrandApple, SSOBrandFacebook, SSOBrandGitHub, SSOBrandGitLab, SSOBrandGoogle, SSOBrandTwitter: + return true + default: + return false + } +} + +const ( + SSOBrandApple = "apple" + SSOBrandFacebook = "facebook" + SSOBrandGitHub = "github" + SSOBrandGitLab = "gitlab" + SSOBrandGoogle = "google" + SSOBrandTwitter = "twitter" +) + type TURN struct { // TODO Guest Support // Whether or not guests can request TURN credentials diff --git a/sytest-blacklist b/sytest-blacklist index a0aba69c74..72c275a4a2 100644 --- a/sytest-blacklist +++ b/sytest-blacklist @@ -30,7 +30,7 @@ Outbound federation can backfill events Alias creators can delete canonical alias with no ops # Blacklisted because we need to implement v2 invite endpoints for room versions -# to be supported (currently fails with M_UNSUPPORTED_ROOM_VERSION) +# to be supported (currently fails with M_UNSUPPORTED_ROOM_VERSION) Inbound federation rejects invites which are not signed by the sender # Blacklisted because we don't support ignores yet @@ -77,3 +77,7 @@ Local device key changes appear in /keys/changes # we don't support groups Remove group category Remove group role + +# Broken +Device list doesn't change if remote server is down +If a device list update goes missing, the server resyncs on the next one diff --git a/sytest-whitelist b/sytest-whitelist index 824038c28f..6b33b37fff 100644 --- a/sytest-whitelist +++ b/sytest-whitelist @@ -147,10 +147,10 @@ Server correctly handles incoming m.device_list_update If remote user leaves room, changes device and rejoins we see update in sync If remote user leaves room, changes device and rejoins we see update in /keys/changes If remote user leaves room we no longer receive device updates -If a device list update goes missing, the server resyncs on the next one +#If a device list update goes missing, the server resyncs on the next one Server correctly resyncs when client query keys and there is no remote cache Server correctly resyncs when server leaves and rejoins a room -Device list doesn't change if remote server is down +#Device list doesn't change if remote server is down Can add account data Can add account data to room Can get account data without syncing @@ -259,6 +259,7 @@ Real non-joined users cannot room initalSync for non-world_readable rooms Push rules come down in an initial /sync Regular users can add and delete aliases in the default room configuration GET /r0/capabilities is not public +login types include SSO GET /joined_rooms lists newly-created room /joined_rooms returns only joined rooms Message history can be paginated over federation @@ -340,17 +341,17 @@ Existing members see new members' join events Inbound federation can receive events Inbound federation can receive redacted events Can logout current device -Can send a message directly to a device using PUT /sendToDevice -Can recv a device message using /sync -Can recv device messages until they are acknowledged -Device messages with the same txn_id are deduplicated -Device messages wake up /sync -Can recv device messages over federation -Device messages over federation wake up /sync -Can send messages with a wildcard device id -Can send messages with a wildcard device id to two devices -Wildcard device messages wake up /sync -Wildcard device messages over federation wake up /sync +Can send a message directly to a device using PUT /sendToDevice +Can recv a device message using /sync +Can recv device messages until they are acknowledged +Device messages with the same txn_id are deduplicated +Device messages wake up /sync +Can recv device messages over federation +Device messages over federation wake up /sync +Can send messages with a wildcard device id +Can send messages with a wildcard device id to two devices +Wildcard device messages wake up /sync +Wildcard device messages over federation wake up /sync Can send a to-device message to two users which both receive it using /sync User can create and send/receive messages in a room with version 6 local user can join room with version 6 @@ -478,7 +479,7 @@ Federation key API can act as a notary server via a GET request Inbound /make_join rejects attempts to join rooms where all users have left Inbound federation rejects invites which include invalid JSON for room version 6 Inbound federation rejects invite rejections which include invalid JSON for room version 6 -GET /capabilities is present and well formed for registered user +GET /capabilities is present and well formed for registered user m.room.history_visibility == "joined" allows/forbids appropriately for Guest users m.room.history_visibility == "joined" allows/forbids appropriately for Real users POST rejects invalid utf-8 in JSON diff --git a/userapi/api/api.go b/userapi/api/api.go index 75d06dd697..e80829e904 100644 --- a/userapi/api/api.go +++ b/userapi/api/api.go @@ -24,6 +24,8 @@ import ( // UserInternalAPI is the internal API for information about users and devices. type UserInternalAPI interface { + LoginTokenInternalAPI + InputAccountData(ctx context.Context, req *InputAccountDataRequest, res *InputAccountDataResponse) error PerformAccountCreation(ctx context.Context, req *PerformAccountCreationRequest, res *PerformAccountCreationResponse) error PerformPasswordUpdate(ctx context.Context, req *PerformPasswordUpdateRequest, res *PerformPasswordUpdateResponse) error diff --git a/userapi/api/api_logintoken.go b/userapi/api/api_logintoken.go new file mode 100644 index 0000000000..8c0a34dfc5 --- /dev/null +++ b/userapi/api/api_logintoken.go @@ -0,0 +1,68 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + "time" +) + +type LoginTokenInternalAPI interface { + // PerformLoginTokenCreation creates a new login token and associates it with the provided data. + PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error + + // PerformLoginTokenDeletion ensures the token doesn't exist. + PerformLoginTokenDeletion(ctx context.Context, req *PerformLoginTokenDeletionRequest, res *PerformLoginTokenDeletionResponse) error + + // QueryLoginToken returns the data associated with a login token. If + // the token is not valid, success is returned, but res.Data == nil. + QueryLoginToken(ctx context.Context, req *QueryLoginTokenRequest, res *QueryLoginTokenResponse) error +} + +// LoginTokenData is the data that can be retrieved given a login token. This is +// provided by the calling code. +type LoginTokenData struct { + // UserID is the full mxid of the user. + UserID string +} + +// LoginTokenMetadata contains metadata created and maintained by the User API. +type LoginTokenMetadata struct { + Token string + Expiration time.Time +} + +type PerformLoginTokenCreationRequest struct { + Data LoginTokenData +} + +type PerformLoginTokenCreationResponse struct { + Metadata LoginTokenMetadata +} + +type PerformLoginTokenDeletionRequest struct { + Token string +} + +type PerformLoginTokenDeletionResponse struct{} + +type QueryLoginTokenRequest struct { + Token string +} + +type QueryLoginTokenResponse struct { + // Data is nil if the token was invalid. + Data *LoginTokenData +} diff --git a/userapi/api/api_trace_logintoken.go b/userapi/api/api_trace_logintoken.go new file mode 100644 index 0000000000..e60dae5941 --- /dev/null +++ b/userapi/api/api_trace_logintoken.go @@ -0,0 +1,39 @@ +// Copyright 2021 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package api + +import ( + "context" + + "github.com/matrix-org/util" +) + +func (t *UserInternalAPITrace) PerformLoginTokenCreation(ctx context.Context, req *PerformLoginTokenCreationRequest, res *PerformLoginTokenCreationResponse) error { + err := t.Impl.PerformLoginTokenCreation(ctx, req, res) + util.GetLogger(ctx).Infof("PerformLoginTokenCreation req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *UserInternalAPITrace) PerformLoginTokenDeletion(ctx context.Context, req *PerformLoginTokenDeletionRequest, res *PerformLoginTokenDeletionResponse) error { + err := t.Impl.PerformLoginTokenDeletion(ctx, req, res) + util.GetLogger(ctx).Infof("PerformLoginTokenDeletion req=%+v res=%+v", js(req), js(res)) + return err +} + +func (t *UserInternalAPITrace) QueryLoginToken(ctx context.Context, req *QueryLoginTokenRequest, res *QueryLoginTokenResponse) error { + err := t.Impl.QueryLoginToken(ctx, req, res) + util.GetLogger(ctx).Infof("QueryLoginToken req=%+v res=%+v", js(req), js(res)) + return err +} diff --git a/userapi/internal/api_logintoken.go b/userapi/internal/api_logintoken.go new file mode 100644 index 0000000000..3db85094e6 --- /dev/null +++ b/userapi/internal/api_logintoken.go @@ -0,0 +1,55 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package internal + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// PerformLoginTokenCreation creates a new login token and associates it with the provided data. +func (a *UserInternalAPI) PerformLoginTokenCreation(ctx context.Context, req *api.PerformLoginTokenCreationRequest, res *api.PerformLoginTokenCreationResponse) error { + util.GetLogger(ctx).WithField("user_id", req.Data.UserID).Info("PerformLoginTokenCreation") + tokenMeta, err := a.DeviceDB.CreateLoginToken(ctx, &req.Data) + if err != nil { + return err + } + res.Metadata = *tokenMeta + return nil +} + +// PerformLoginTokenDeletion ensures the token doesn't exist. +func (a *UserInternalAPI) PerformLoginTokenDeletion(ctx context.Context, req *api.PerformLoginTokenDeletionRequest, res *api.PerformLoginTokenDeletionResponse) error { + util.GetLogger(ctx).Info("PerformLoginTokenDeletion") + return a.DeviceDB.RemoveLoginToken(ctx, req.Token) +} + +// QueryLoginToken returns the data associated with a login token. If +// the token is not valid, success is returned, but res.Data == nil. +func (a *UserInternalAPI) QueryLoginToken(ctx context.Context, req *api.QueryLoginTokenRequest, res *api.QueryLoginTokenResponse) error { + tokenData, err := a.DeviceDB.GetLoginTokenByToken(ctx, req.Token) + if err != nil { + res.Data = nil + if err == sql.ErrNoRows { + return nil + } + return err + } + res.Data = tokenData + return nil +} diff --git a/userapi/inthttp/client_logintoken.go b/userapi/inthttp/client_logintoken.go new file mode 100644 index 0000000000..a68d32c1d7 --- /dev/null +++ b/userapi/inthttp/client_logintoken.go @@ -0,0 +1,65 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inthttp + +import ( + "context" + + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/opentracing/opentracing-go" +) + +const ( + PerformLoginTokenCreationPath = "/userapi/performLoginTokenCreation" + PerformLoginTokenDeletionPath = "/userapi/performLoginTokenDeletion" + QueryLoginTokenPath = "/userapi/queryLoginToken" +) + +func (h *httpUserInternalAPI) PerformLoginTokenCreation( + ctx context.Context, + request *api.PerformLoginTokenCreationRequest, + response *api.PerformLoginTokenCreationResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenCreation") + defer span.Finish() + + apiURL := h.apiURL + PerformLoginTokenCreationPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) PerformLoginTokenDeletion( + ctx context.Context, + request *api.PerformLoginTokenDeletionRequest, + response *api.PerformLoginTokenDeletionResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "PerformLoginTokenDeletion") + defer span.Finish() + + apiURL := h.apiURL + PerformLoginTokenDeletionPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} + +func (h *httpUserInternalAPI) QueryLoginToken( + ctx context.Context, + request *api.QueryLoginTokenRequest, + response *api.QueryLoginTokenResponse, +) error { + span, ctx := opentracing.StartSpanFromContext(ctx, "QueryLoginToken") + defer span.Finish() + + apiURL := h.apiURL + QueryLoginTokenPath + return httputil.PostJSON(ctx, span, h.httpClient, apiURL, request, response) +} diff --git a/userapi/inthttp/server.go b/userapi/inthttp/server.go index 90db6bc4ec..26b883bff0 100644 --- a/userapi/inthttp/server.go +++ b/userapi/inthttp/server.go @@ -26,6 +26,8 @@ import ( // nolint: gocyclo func AddRoutes(internalAPIMux *mux.Router, s api.UserInternalAPI) { + addRoutesLoginToken(internalAPIMux, s) + internalAPIMux.Handle(PerformAccountCreationPath, httputil.MakeInternalAPI("performAccountCreation", func(req *http.Request) util.JSONResponse { request := api.PerformAccountCreationRequest{} diff --git a/userapi/inthttp/server_logintoken.go b/userapi/inthttp/server_logintoken.go new file mode 100644 index 0000000000..5a4ffe3da4 --- /dev/null +++ b/userapi/inthttp/server_logintoken.go @@ -0,0 +1,68 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package inthttp + +import ( + "encoding/json" + "net/http" + + "github.com/gorilla/mux" + "github.com/matrix-org/dendrite/internal/httputil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +// addRoutesLoginToken adds routes for all login token API calls. +func addRoutesLoginToken(internalAPIMux *mux.Router, s api.UserInternalAPI) { + internalAPIMux.Handle(PerformLoginTokenCreationPath, + httputil.MakeInternalAPI("performLoginTokenCreation", func(req *http.Request) util.JSONResponse { + request := api.PerformLoginTokenCreationRequest{} + response := api.PerformLoginTokenCreationResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformLoginTokenCreation(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(PerformLoginTokenDeletionPath, + httputil.MakeInternalAPI("performLoginTokenDeletion", func(req *http.Request) util.JSONResponse { + request := api.PerformLoginTokenDeletionRequest{} + response := api.PerformLoginTokenDeletionResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.PerformLoginTokenDeletion(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) + internalAPIMux.Handle(QueryLoginTokenPath, + httputil.MakeInternalAPI("queryLoginToken", func(req *http.Request) util.JSONResponse { + request := api.QueryLoginTokenRequest{} + response := api.QueryLoginTokenResponse{} + if err := json.NewDecoder(req.Body).Decode(&request); err != nil { + return util.MessageResponse(http.StatusBadRequest, err.Error()) + } + if err := s.QueryLoginToken(req.Context(), &request, &response); err != nil { + return util.ErrorResponse(err) + } + return util.JSONResponse{Code: http.StatusOK, JSON: &response} + }), + ) +} diff --git a/userapi/storage/devices/interface.go b/userapi/storage/devices/interface.go index 95fe99f338..cb7602f0a8 100644 --- a/userapi/storage/devices/interface.go +++ b/userapi/storage/devices/interface.go @@ -38,4 +38,15 @@ type Database interface { RemoveDevices(ctx context.Context, localpart string, devices []string) error // RemoveAllDevices deleted all devices for this user. Returns the devices deleted. RemoveAllDevices(ctx context.Context, localpart, exceptDeviceID string) (devices []api.Device, err error) + + // CreateLoginToken generates a token, stores and returns it. The lifetime is + // determined by the loginTokenLifetime given to the Database constructor. + CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) + + // RemoveLoginToken removes the named token (and may clean up other expired tokens). + RemoveLoginToken(ctx context.Context, token string) error + + // GetLoginTokenByToken returns the data associated with the given token. + // May return sql.ErrNoRows. + GetLoginTokenByToken(ctx context.Context, token string) (*api.LoginTokenData, error) } diff --git a/userapi/storage/devices/postgres/logintoken_table.go b/userapi/storage/devices/postgres/logintoken_table.go new file mode 100644 index 0000000000..2bdca539aa --- /dev/null +++ b/userapi/storage/devices/postgres/logintoken_table.go @@ -0,0 +1,92 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package postgres + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type loginTokenStatements struct { + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectStmt *sql.Stmt +} + +// execSchema ensures tables and indices exist. +func (s *loginTokenStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(` +CREATE TABLE IF NOT EXISTS login_tokens ( + -- The random value of the token issued to a user + token TEXT NOT NULL PRIMARY KEY, + -- When the token expires + token_expires_at TIMESTAMP NOT NULL, + + -- The mxid for this account + user_id TEXT NOT NULL +); + +-- This index allows efficient garbage collection of expired tokens. +CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); +`) + return err +} + +// prepare runs statement preparation. +func (s *loginTokenStatements) prepare(db *sql.DB) error { + return sqlutil.StatementList{ + {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, + {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= CURRENT_TIMESTAMP"}, + {&s.selectStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > CURRENT_TIMESTAMP"}, + }.Prepare(db) +} + +// insert adds an already generated token to the database. +func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { + stmt := sqlutil.TxStmt(txn, s.insertStmt) + _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) + return err +} + +// deleteByToken removes the named token. +// +// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. +// The login_tokens_expiration_idx index should make that efficient. +func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { + stmt := sqlutil.TxStmt(txn, s.deleteStmt) + res, err := stmt.ExecContext(ctx, token) + if err != nil { + return err + } + if n, err := res.RowsAffected(); err == nil && n > 1 { + util.GetLogger(ctx).WithField("num_deleted", n).Infof("Deleted %d login tokens (%d likely garbage collected)", n, n-1) + } + return nil +} + +// selectByToken returns the data associated with the given token. May return sql.ErrNoRows. +func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + var data api.LoginTokenData + err := s.selectStmt.QueryRowContext(ctx, token).Scan(&data.UserID) + if err != nil { + return nil, err + } + + return &data, nil +} diff --git a/userapi/storage/devices/postgres/storage.go b/userapi/storage/devices/postgres/storage.go index 4852343316..8817880660 100644 --- a/userapi/storage/devices/postgres/storage.go +++ b/userapi/storage/devices/postgres/storage.go @@ -19,6 +19,7 @@ import ( "crypto/rand" "database/sql" "encoding/base64" + "time" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" @@ -27,28 +28,38 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -// The length of generated device IDs -var deviceIDByteLength = 6 +const ( + // The length of generated device IDs + deviceIDByteLength = 6 + loginTokenByteLength = 32 +) // Database represents a device database. type Database struct { - db *sql.DB - devices devicesStatements + db *sql.DB + devices devicesStatements + loginTokens loginTokenStatements + loginTokenLifetime time.Duration } // NewDatabase creates a new device database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err } - d := devicesStatements{} + var d devicesStatements + var lt loginTokenStatements // Create tables before executing migrations so we don't fail if the table is missing, // and THEN prepare statements so we don't fail due to referencing new columns if err = d.execSchema(db); err != nil { return nil, err } + if err = lt.execSchema(db); err != nil { + return nil, err + } + m := sqlutil.NewMigrations() deltas.LoadLastSeenTSIP(m) if err = m.RunDeltas(db, dbProperties); err != nil { @@ -58,8 +69,11 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.prepare(db, serverName); err != nil { return nil, err } + if err = lt.prepare(db); err != nil { + return nil, err + } - return &Database{db, d}, nil + return &Database{db, d, lt, loginTokenLifetime}, nil } // GetDeviceByAccessToken returns the device matching the given access token. @@ -210,3 +224,47 @@ func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) }) } + +// CreateLoginToken generates a token, stores and returns it. The lifetime is +// determined by the loginTokenLifetime given to the Database constructor. +func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { + tok, err := generateLoginToken() + if err != nil { + return nil, err + } + meta := &api.LoginTokenMetadata{ + Token: tok, + Expiration: time.Now().Add(d.loginTokenLifetime), + } + + err = sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.loginTokens.insert(ctx, txn, meta, data) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +func generateLoginToken() (string, error) { + b := make([]byte, loginTokenByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// RemoveLoginToken removes the named token (and may clean up other expired tokens). +func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { + return sqlutil.WithTransaction(d.db, func(txn *sql.Tx) error { + return d.loginTokens.deleteByToken(ctx, txn, token) + }) +} + +// GetLoginTokenByToken returns the data associated with the given token. +// May return sql.ErrNoRows. +func (d *Database) GetLoginTokenByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + return d.loginTokens.selectByToken(ctx, token) +} diff --git a/userapi/storage/devices/sqlite3/logintoken_table.go b/userapi/storage/devices/sqlite3/logintoken_table.go new file mode 100644 index 0000000000..ef6a75184b --- /dev/null +++ b/userapi/storage/devices/sqlite3/logintoken_table.go @@ -0,0 +1,92 @@ +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package sqlite3 + +import ( + "context" + "database/sql" + + "github.com/matrix-org/dendrite/internal/sqlutil" + "github.com/matrix-org/dendrite/userapi/api" + "github.com/matrix-org/util" +) + +type loginTokenStatements struct { + insertStmt *sql.Stmt + deleteStmt *sql.Stmt + selectStmt *sql.Stmt +} + +// execSchema ensures tables and indices exist. +func (s *loginTokenStatements) execSchema(db *sql.DB) error { + _, err := db.Exec(` +CREATE TABLE IF NOT EXISTS login_tokens ( + -- The random value of the token issued to a user + token TEXT NOT NULL PRIMARY KEY, + -- When the token expires + token_expires_at TIMESTAMP NOT NULL, + + -- The mxid for this account + user_id TEXT NOT NULL +); + +-- This index allows efficient garbage collection of expired tokens. +CREATE INDEX IF NOT EXISTS login_tokens_expiration_idx ON login_tokens(token_expires_at); +`) + return err +} + +// prepare runs statement preparation. +func (s *loginTokenStatements) prepare(db *sql.DB) error { + return sqlutil.StatementList{ + {&s.insertStmt, "INSERT INTO login_tokens(token, token_expires_at, user_id) VALUES ($1, $2, $3)"}, + {&s.deleteStmt, "DELETE FROM login_tokens WHERE token = $1 OR token_expires_at <= CURRENT_TIMESTAMP"}, + {&s.selectStmt, "SELECT user_id FROM login_tokens WHERE token = $1 AND token_expires_at > CURRENT_TIMESTAMP"}, + }.Prepare(db) +} + +// insert adds an already generated token to the database. +func (s *loginTokenStatements) insert(ctx context.Context, txn *sql.Tx, metadata *api.LoginTokenMetadata, data *api.LoginTokenData) error { + stmt := sqlutil.TxStmt(txn, s.insertStmt) + _, err := stmt.ExecContext(ctx, metadata.Token, metadata.Expiration.UTC(), data.UserID) + return err +} + +// deleteByToken removes the named token. +// +// As a simple way to garbage-collect stale tokens, we also remove all expired tokens. +// The login_tokens_expiration_idx index should make that efficient. +func (s *loginTokenStatements) deleteByToken(ctx context.Context, txn *sql.Tx, token string) error { + stmt := sqlutil.TxStmt(txn, s.deleteStmt) + res, err := stmt.ExecContext(ctx, token) + if err != nil { + return err + } + if n, err := res.RowsAffected(); err == nil && n > 1 { + util.GetLogger(ctx).WithField("num_deleted", n).Infof("Deleted %d login tokens (%d likely garbage collected)", n, n-1) + } + return nil +} + +// selectByToken returns the data associated with the given token. May return sql.ErrNoRows. +func (s *loginTokenStatements) selectByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + var data api.LoginTokenData + err := s.selectStmt.QueryRowContext(ctx, token).Scan(&data.UserID) + if err != nil { + return nil, err + } + + return &data, nil +} diff --git a/userapi/storage/devices/sqlite3/storage.go b/userapi/storage/devices/sqlite3/storage.go index 538644837f..7f76166b90 100644 --- a/userapi/storage/devices/sqlite3/storage.go +++ b/userapi/storage/devices/sqlite3/storage.go @@ -19,6 +19,7 @@ import ( "crypto/rand" "database/sql" "encoding/base64" + "time" "github.com/matrix-org/dendrite/internal/sqlutil" "github.com/matrix-org/dendrite/setup/config" @@ -27,30 +28,41 @@ import ( "github.com/matrix-org/gomatrixserverlib" ) -// The length of generated device IDs -var deviceIDByteLength = 6 +const ( + // The length of generated device IDs + deviceIDByteLength = 6 + + loginTokenByteLength = 32 +) // Database represents a device database. type Database struct { - db *sql.DB - writer sqlutil.Writer - devices devicesStatements + db *sql.DB + writer sqlutil.Writer + devices devicesStatements + loginTokens loginTokenStatements + loginTokenLifetime time.Duration } // NewDatabase creates a new device database -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (*Database, error) { +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (*Database, error) { db, err := sqlutil.Open(dbProperties) if err != nil { return nil, err } writer := sqlutil.NewExclusiveWriter() - d := devicesStatements{} + var d devicesStatements + var lt loginTokenStatements // Create tables before executing migrations so we don't fail if the table is missing, // and THEN prepare statements so we don't fail due to referencing new columns if err = d.execSchema(db); err != nil { return nil, err } + if err = lt.execSchema(db); err != nil { + return nil, err + } + m := sqlutil.NewMigrations() deltas.LoadLastSeenTSIP(m) if err = m.RunDeltas(db, dbProperties); err != nil { @@ -59,7 +71,10 @@ func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserver if err = d.prepare(db, writer, serverName); err != nil { return nil, err } - return &Database{db, writer, d}, nil + if err = lt.prepare(db); err != nil { + return nil, err + } + return &Database{db, writer, d, lt, loginTokenLifetime}, nil } // GetDeviceByAccessToken returns the device matching the given access token. @@ -210,3 +225,47 @@ func (d *Database) UpdateDeviceLastSeen(ctx context.Context, localpart, deviceID return d.devices.updateDeviceLastSeen(ctx, txn, localpart, deviceID, ipAddr) }) } + +// CreateLoginToken generates a token, stores and returns it. The lifetime is +// determined by the loginTokenLifetime given to the Database constructor. +func (d *Database) CreateLoginToken(ctx context.Context, data *api.LoginTokenData) (*api.LoginTokenMetadata, error) { + tok, err := generateLoginToken() + if err != nil { + return nil, err + } + meta := &api.LoginTokenMetadata{ + Token: tok, + Expiration: time.Now().Add(d.loginTokenLifetime), + } + + err = d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.loginTokens.insert(ctx, txn, meta, data) + }) + if err != nil { + return nil, err + } + + return meta, nil +} + +func generateLoginToken() (string, error) { + b := make([]byte, loginTokenByteLength) + _, err := rand.Read(b) + if err != nil { + return "", err + } + return base64.RawURLEncoding.EncodeToString(b), nil +} + +// RemoveLoginToken removes the named token (and may clean up other expired tokens). +func (d *Database) RemoveLoginToken(ctx context.Context, token string) error { + return d.writer.Do(d.db, nil, func(txn *sql.Tx) error { + return d.loginTokens.deleteByToken(ctx, txn, token) + }) +} + +// GetLoginTokenByToken returns the data associated with the given token. +// May return sql.ErrNoRows. +func (d *Database) GetLoginTokenByToken(ctx context.Context, token string) (*api.LoginTokenData, error) { + return d.loginTokens.selectByToken(ctx, token) +} diff --git a/userapi/storage/devices/storage.go b/userapi/storage/devices/storage.go index bfce924d99..121f38d56d 100644 --- a/userapi/storage/devices/storage.go +++ b/userapi/storage/devices/storage.go @@ -18,6 +18,7 @@ package devices import ( "fmt" + "time" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/devices/postgres" @@ -26,13 +27,14 @@ import ( ) // NewDatabase opens a new Postgres or Sqlite database (based on dataSourceName scheme) -// and sets postgres connection parameters -func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName) (Database, error) { +// and sets postgres connection parameters. loginTokenLifetime determines how long a +// login token from CreateLoginToken is valid. +func NewDatabase(dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, loginTokenLifetime time.Duration) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName) + return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime) case dbProperties.ConnectionString.IsPostgres(): - return postgres.NewDatabase(dbProperties, serverName) + return postgres.NewDatabase(dbProperties, serverName, loginTokenLifetime) default: return nil, fmt.Errorf("unexpected database type") } diff --git a/userapi/storage/devices/storage_wasm.go b/userapi/storage/devices/storage_wasm.go index f360f98572..3de7880b9c 100644 --- a/userapi/storage/devices/storage_wasm.go +++ b/userapi/storage/devices/storage_wasm.go @@ -16,6 +16,7 @@ package devices import ( "fmt" + "time" "github.com/matrix-org/dendrite/setup/config" "github.com/matrix-org/dendrite/userapi/storage/devices/sqlite3" @@ -25,10 +26,11 @@ import ( func NewDatabase( dbProperties *config.DatabaseOptions, serverName gomatrixserverlib.ServerName, + loginTokenLifetime time.Duration, ) (Database, error) { switch { case dbProperties.ConnectionString.IsSQLite(): - return sqlite3.NewDatabase(dbProperties, serverName) + return sqlite3.NewDatabase(dbProperties, serverName, loginTokenLifetime) case dbProperties.ConnectionString.IsPostgres(): return nil, fmt.Errorf("can't use Postgres implementation") default: diff --git a/userapi/userapi.go b/userapi/userapi.go index 74702020ab..c7e1f66749 100644 --- a/userapi/userapi.go +++ b/userapi/userapi.go @@ -15,6 +15,8 @@ package userapi import ( + "time" + "github.com/gorilla/mux" keyapi "github.com/matrix-org/dendrite/keyserver/api" "github.com/matrix-org/dendrite/setup/config" @@ -26,6 +28,13 @@ import ( "github.com/sirupsen/logrus" ) +// defaultLoginTokenLifetime determines how old a valid token may be. +// +// NOTSPEC: The current spec says "SHOULD be limited to around five +// seconds". Since TCP retries are on the order of 3 s, 5 s sounds very low. +// Synapse uses 2 min (https://github.com/matrix-org/synapse/blob/78d5f91de1a9baf4dbb0a794cb49a799f29f7a38/synapse/handlers/auth.py#L1323-L1325). +const defaultLoginTokenLifetime = 2 * time.Minute + // AddInternalRoutes registers HTTP handlers for the internal API. Invokes functions // on the given input API. func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { @@ -37,11 +46,21 @@ func AddInternalRoutes(router *mux.Router, intAPI api.UserInternalAPI) { func NewInternalAPI( accountDB accounts.Database, cfg *config.UserAPI, appServices []config.ApplicationService, keyAPI keyapi.KeyInternalAPI, ) api.UserInternalAPI { - deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName) + deviceDB, err := devices.NewDatabase(&cfg.DeviceDatabase, cfg.Matrix.ServerName, defaultLoginTokenLifetime) if err != nil { logrus.WithError(err).Panicf("failed to connect to device db") } + return newInternalAPI(accountDB, deviceDB, cfg, appServices, keyAPI) +} + +func newInternalAPI( + accountDB accounts.Database, + deviceDB devices.Database, + cfg *config.UserAPI, + appServices []config.ApplicationService, + keyAPI keyapi.KeyInternalAPI, +) api.UserInternalAPI { return &internal.UserInternalAPI{ AccountDB: accountDB, DeviceDB: deviceDB, diff --git a/userapi/userapi_test.go b/userapi/userapi_test.go index 0141258e68..52ab123773 100644 --- a/userapi/userapi_test.go +++ b/userapi/userapi_test.go @@ -1,4 +1,18 @@ -package userapi_test +// Copyright 2020 The Matrix.org Foundation C.I.C. +// +// Licensed under the Apache License, Version 2.0 (the "License"); +// you may not use this file except in compliance with the License. +// You may obtain a copy of the License at +// +// http://www.apache.org/licenses/LICENSE-2.0 +// +// Unless required by applicable law or agreed to in writing, software +// distributed under the License is distributed on an "AS IS" BASIS, +// WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied. +// See the License for the specific language governing permissions and +// limitations under the License. + +package userapi import ( "context" @@ -6,15 +20,16 @@ import ( "net/http" "reflect" "testing" + "time" "github.com/gorilla/mux" "github.com/matrix-org/dendrite/internal/httputil" "github.com/matrix-org/dendrite/internal/test" "github.com/matrix-org/dendrite/setup/config" - "github.com/matrix-org/dendrite/userapi" "github.com/matrix-org/dendrite/userapi/api" "github.com/matrix-org/dendrite/userapi/inthttp" "github.com/matrix-org/dendrite/userapi/storage/accounts" + "github.com/matrix-org/dendrite/userapi/storage/devices" "github.com/matrix-org/gomatrixserverlib" "golang.org/x/crypto/bcrypt" ) @@ -23,31 +38,41 @@ const ( serverName = gomatrixserverlib.ServerName("example.com") ) -func MustMakeInternalAPI(t *testing.T) (api.UserInternalAPI, accounts.Database) { - accountDB, err := accounts.NewDatabase(&config.DatabaseOptions{ - ConnectionString: "file::memory:", - }, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS) +type apiTestOpts struct { + loginTokenLifetime time.Duration +} + +func MustMakeInternalAPI(t *testing.T, opts apiTestOpts) (api.UserInternalAPI, accounts.Database) { + if opts.loginTokenLifetime == 0 { + opts.loginTokenLifetime = defaultLoginTokenLifetime + } + dbopts := &config.DatabaseOptions{ + ConnectionString: "file::memory:", + MaxOpenConnections: 1, + MaxIdleConnections: 1, + } + accountDB, err := accounts.NewDatabase(dbopts, serverName, bcrypt.MinCost, config.DefaultOpenIDTokenLifetimeMS) if err != nil { t.Fatalf("failed to create account DB: %s", err) } + deviceDB, err := devices.NewDatabase(dbopts, serverName, opts.loginTokenLifetime) + if err != nil { + t.Fatalf("failed to create device DB: %s", err) + } + cfg := &config.UserAPI{ - DeviceDatabase: config.DatabaseOptions{ - ConnectionString: "file::memory:", - MaxOpenConnections: 1, - MaxIdleConnections: 1, - }, Matrix: &config.Global{ ServerName: serverName, }, } - return userapi.NewInternalAPI(accountDB, cfg, nil, nil), accountDB + return newInternalAPI(accountDB, deviceDB, cfg, nil, nil), accountDB } func TestQueryProfile(t *testing.T) { aliceAvatarURL := "mxc://example.com/alice" aliceDisplayName := "Alice" - userAPI, accountDB := MustMakeInternalAPI(t) + userAPI, accountDB := MustMakeInternalAPI(t, apiTestOpts{}) _, err := accountDB.CreateAccount(context.TODO(), "alice", "foobar", "") if err != nil { t.Fatalf("failed to make account: %s", err) @@ -106,7 +131,7 @@ func TestQueryProfile(t *testing.T) { t.Run("HTTP API", func(t *testing.T) { router := mux.NewRouter().PathPrefix(httputil.InternalPathPrefix).Subrouter() - userapi.AddInternalRoutes(router, userAPI) + AddInternalRoutes(router, userAPI) apiURL, cancel := test.ListenAndServe(t, router, false) defer cancel() httpAPI, err := inthttp.NewUserAPIClient(apiURL, &http.Client{}) @@ -119,3 +144,111 @@ func TestQueryProfile(t *testing.T) { runCases(userAPI) }) } + +func TestLoginToken(t *testing.T) { + ctx := context.Background() + + t.Run("tokenLoginFlow", func(t *testing.T) { + userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{}) + + t.Log("Creating a login token like the SSO callback would...") + + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "auser"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } + + if cresp.Metadata.Token == "" { + t.Errorf("PerformLoginTokenCreation Token: got %q, want non-empty", cresp.Metadata.Token) + } + if cresp.Metadata.Expiration.Before(time.Now()) { + t.Errorf("PerformLoginTokenCreation Expiration: got %v, want non-expired", cresp.Metadata.Expiration) + } + + t.Log("Querying the login token like /login with m.login.token would...") + + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } + + if qresp.Data == nil { + t.Errorf("QueryLoginToken Data: got %v, want non-nil", qresp.Data) + } + if want := "auser"; qresp.Data.UserID == "" { + t.Errorf("QueryLoginToken UserID: got %q, want %q", qresp.Data.UserID, want) + } + + t.Log("Deleting the login token like /login with m.login.token would...") + + dreq := api.PerformLoginTokenDeletionRequest{Token: cresp.Metadata.Token} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + }) + + t.Run("expiredTokenIsNotReturned", func(t *testing.T) { + userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{loginTokenLifetime: -1 * time.Second}) + + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "auser"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } + + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } + + if qresp.Data != nil { + t.Errorf("QueryLoginToken Data: got %v, want nil", qresp.Data) + } + }) + + t.Run("deleteWorks", func(t *testing.T) { + userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{}) + + creq := api.PerformLoginTokenCreationRequest{ + Data: api.LoginTokenData{UserID: "auser"}, + } + var cresp api.PerformLoginTokenCreationResponse + if err := userAPI.PerformLoginTokenCreation(ctx, &creq, &cresp); err != nil { + t.Fatalf("PerformLoginTokenCreation failed: %v", err) + } + + dreq := api.PerformLoginTokenDeletionRequest{Token: cresp.Metadata.Token} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + + qreq := api.QueryLoginTokenRequest{Token: cresp.Metadata.Token} + var qresp api.QueryLoginTokenResponse + if err := userAPI.QueryLoginToken(ctx, &qreq, &qresp); err != nil { + t.Fatalf("QueryLoginToken failed: %v", err) + } + + if qresp.Data != nil { + t.Errorf("QueryLoginToken Data: got %v, want nil", qresp.Data) + } + }) + + t.Run("deleteUnknownIsNoOp", func(t *testing.T) { + userAPI, _ := MustMakeInternalAPI(t, apiTestOpts{}) + + dreq := api.PerformLoginTokenDeletionRequest{Token: "non-existent token"} + var dresp api.PerformLoginTokenDeletionResponse + if err := userAPI.PerformLoginTokenDeletion(ctx, &dreq, &dresp); err != nil { + t.Fatalf("PerformLoginTokenDeletion failed: %v", err) + } + }) +}