From 103574716ab49142054ea4c83484f1628ec23fed Mon Sep 17 00:00:00 2001 From: Robin Muhia Date: Wed, 24 Jul 2024 15:16:38 +0300 Subject: [PATCH] feat: refactor refresh of access and refresh tokens --- client.go | 57 ++++++++++++++++++------------------------------------- 1 file changed, 18 insertions(+), 39 deletions(-) diff --git a/client.go b/client.go index 7f766e3..c6910ca 100644 --- a/client.go +++ b/client.go @@ -26,10 +26,7 @@ var ( // accessTokenTimeout shows the access token expiry time. // After the access token expires, one is required to obtain a new one - accessTokenTimeout = 60 * time.Minute - - // refreshTokenTimeout shows the refresh token expiry time - refreshTokenTimeout = 24 * time.Hour + accessTokenTimeout = 59 * time.Minute ) // AuthServerImpl defines the methods provided by @@ -44,8 +41,7 @@ type client struct { authServer AuthServerImpl client *http.Client - refreshToken string - refreshTokenTicker *time.Ticker + refreshToken string accessToken string accessTokenTicker *time.Ticker @@ -86,32 +82,23 @@ func mustNewClient(authServer AuthServerImpl) *client { return client } -// executed as a go routine to update the api tokens when they timeout +// executed as a go routine to update access and refresh token func (s *client) background() { - for { - select { - case t := <-s.refreshTokenTicker.C: - logrus.Println("SIL Comms Refresh Token updated at: ", t) - err := s.login() - if err != nil { - s.authFailed = true - } - s.authFailed = false - - case t := <-s.accessTokenTicker.C: - logrus.Println("SIL Comms Access Token updated at: ", t) - err := s.refreshAccessToken() - if err != nil { - s.authFailed = true - } + for t := range s.accessTokenTicker.C { + logrus.Println("SIL Comms Access Token updated at: ", t) + err := s.refreshAccessToken() + if err != nil { + s.authFailed = true + } else { s.authFailed = false } } } // setAccessToken sets the access token and updates the ticker timer -func (s *client) setAccessToken(token string) { - s.accessToken = token +func (s *client) setRefreshAndAccessToken(token *TokenResponse) { + s.accessToken = token.Access + s.refreshToken = token.Refresh if s.accessTokenTicker != nil { s.accessTokenTicker.Reset(accessTokenTimeout) } else { @@ -119,16 +106,6 @@ func (s *client) setAccessToken(token string) { } } -// setRefreshToken sets the access token and updates the ticker timer -func (s *client) setRefreshToken(token string) { - s.refreshToken = token - if s.refreshTokenTicker != nil { - s.refreshTokenTicker.Reset(refreshTokenTimeout) - } else { - s.refreshTokenTicker = time.NewTicker(refreshTokenTimeout) - } -} - // login uses the provided credentials to login to the authserver backend // It obtains the necessary tokens required to make authenticated requests func (s *client) login() error { @@ -149,12 +126,13 @@ func (s *client) login() error { Refresh: resp.RefreshToken, } - s.setRefreshToken(tokens.Refresh) - s.setAccessToken(tokens.Access) + s.setRefreshAndAccessToken(&tokens) return nil } +// refreshAccessToken makes a request to get +// new access and refresh tokens func (s *client) refreshAccessToken() error { ctx := context.Background() resp, err := s.authServer.RefreshToken(ctx, s.refreshToken) @@ -163,10 +141,11 @@ func (s *client) refreshAccessToken() error { } tokens := TokenResponse{ - Access: resp.AccessToken, + Access: resp.AccessToken, + Refresh: resp.RefreshToken, } - s.setAccessToken(tokens.Access) + s.setRefreshAndAccessToken(&tokens) return nil }