From ed90ebbdd316a110c46ff82e334e4d16637bc54d Mon Sep 17 00:00:00 2001 From: Peter Stuifzand Date: Sat, 23 Mar 2019 21:29:48 +0100 Subject: [PATCH] Improve error handling in authentication --- cmd/eksterd/auth.go | 63 +++++++++++++++++++++++++------------------ cmd/eksterd/http.go | 5 +++- cmd/eksterd/main.go | 6 ++++- cmd/eksterd/memory.go | 13 ++++++--- 4 files changed, 55 insertions(+), 32 deletions(-) diff --git a/cmd/eksterd/auth.go b/cmd/eksterd/auth.go index 76f946b..bee5e68 100644 --- a/cmd/eksterd/auth.go +++ b/cmd/eksterd/auth.go @@ -9,90 +9,101 @@ import ( "time" "github.com/gomodule/redigo/redis" + "github.com/pkg/errors" "p83.nl/go/ekster/pkg/auth" ) var authHeaderRegex = regexp.MustCompile("^Bearer (.+)$") -func (b *memoryBackend) cachedCheckAuthToken(conn redis.Conn, header string, r *auth.TokenResponse) bool { +func (b *memoryBackend) cachedCheckAuthToken(conn redis.Conn, header string, r *auth.TokenResponse) (bool, error) { tokens := authHeaderRegex.FindStringSubmatch(header) if len(tokens) != 2 { - log.Println("No token found in the header") - return false + return false, fmt.Errorf("could not find token in header") } key := fmt.Sprintf("token:%s", tokens[1]) authorized, err := getCachedValue(conn, key, r) if err != nil { - log.Println(err) + log.Printf("could not get cached auth token value: %v", err) } if authorized { - return true + return true, nil + } + + authorized, err = b.checkAuthToken(header, r) + if err != nil { + return false, errors.Wrap(err, "could not check auth token") } - authorized = b.checkAuthToken(header, r) if authorized { err = setCachedTokenResponseValue(conn, key, r) if err != nil { - log.Println(err) + log.Printf("could not set cached token response value: %v", err) } - return true + + return true, nil } - return authorized + return authorized, nil } -func (b *memoryBackend) checkAuthToken(header string, token *auth.TokenResponse) bool { - log.Println("Checking auth token") - +func (b *memoryBackend) checkAuthToken(header string, token *auth.TokenResponse) (bool, error) { tokenEndpoint := b.TokenEndpoint req, err := buildValidateAuthTokenRequest(tokenEndpoint, header) if err != nil { - return false + return false, err } client := http.Client{} res, err := client.Do(req) if err != nil { - log.Println(err) - return false + return false, err } - defer res.Body.Close() + defer func() { + err := res.Body.Close() + if err != nil { + log.Printf("could not close http response body: %v", err) + } + }() if res.StatusCode < 200 || res.StatusCode >= 300 { - log.Printf("HTTP StatusCode when verifying token: %d\n", res.StatusCode) - return false + return false, fmt.Errorf("got unsuccessfull http status code while verifying token: %d", res.StatusCode) } dec := json.NewDecoder(res.Body) err = dec.Decode(&token) if err != nil { - log.Printf("Error in json object: %v", err) - return false + return false, errors.Wrap(err, "could not decode json body") } - log.Println("Auth Token: Success") - return true + return true, nil } func buildValidateAuthTokenRequest(tokenEndpoint string, header string) (*http.Request, error) { req, err := http.NewRequest("GET", tokenEndpoint, nil) + if err != nil { + return nil, errors.Wrap(err, "could not create a new request") + } req.Header.Add("Authorization", header) req.Header.Add("Accept", "application/json") - return req, err + + return req, nil } // setCachedTokenResponseValue remembers the value of the auth token response in redis func setCachedTokenResponseValue(conn redis.Conn, key string, r *auth.TokenResponse) error { _, err := conn.Do("HMSET", redis.Args{}.Add(key).AddFlat(r)...) if err != nil { - return fmt.Errorf("error while setting token: %v", err) + return errors.Wrap(err, "could not remember token") + } + _, err = conn.Do("EXPIRE", key, uint64(10*time.Minute/time.Second)) + if err != nil { + return errors.Wrap(err, "could not set expiration for token") } - conn.Do("EXPIRE", key, uint64(10*time.Minute/time.Second)) return nil } @@ -100,7 +111,7 @@ func setCachedTokenResponseValue(conn redis.Conn, key string, r *auth.TokenRespo func getCachedValue(conn redis.Conn, key string, r *auth.TokenResponse) (bool, error) { values, err := redis.Values(conn.Do("HGETALL", key)) if err != nil { - return false, fmt.Errorf("error while getting value from backend: %v", err) + return false, errors.Wrap(err, "could not get value from backend") } if len(values) > 0 { diff --git a/cmd/eksterd/http.go b/cmd/eksterd/http.go index 70ac289..9a9e72c 100644 --- a/cmd/eksterd/http.go +++ b/cmd/eksterd/http.go @@ -127,7 +127,7 @@ func getSessionCookie(w http.ResponseWriter, r *http.Request) string { } http.SetCookie(w, newCookie) - } else { + } else if err == nil { sessionVar = c.Value } @@ -295,6 +295,9 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { if err == http.ErrNoCookie { http.Redirect(w, r, "/", 302) return + } else if err != nil { + http.Error(w, "could not read cookie", 500) + return } sessionVar := c.Value diff --git a/cmd/eksterd/main.go b/cmd/eksterd/main.go index b310c3e..0d49e42 100644 --- a/cmd/eksterd/main.go +++ b/cmd/eksterd/main.go @@ -74,7 +74,11 @@ func WithAuth(handler http.Handler, b *memoryBackend) http.Handler { var token auth.TokenResponse - if !b.AuthTokenAccepted(authorization, &token) { + authorized, err := b.AuthTokenAccepted(authorization, &token) + if err != nil { + log.Printf("token not accepted: %v", err) + } + if !authorized { log.Printf("Token could not be validated") http.Error(w, "Can't validate token", 403) return diff --git a/cmd/eksterd/memory.go b/cmd/eksterd/memory.go index 5d1e1fb..aa76ab2 100644 --- a/cmd/eksterd/memory.go +++ b/cmd/eksterd/memory.go @@ -67,9 +67,14 @@ func (f *fetch2) Fetch(url string) (*http.Response, error) { return Fetch2(url) } -func (b *memoryBackend) AuthTokenAccepted(header string, r *auth.TokenResponse) bool { +func (b *memoryBackend) AuthTokenAccepted(header string, r *auth.TokenResponse) (bool, error) { conn := b.pool.Get() - defer conn.Close() + defer func() { + err := conn.Close() + if err != nil { + log.Printf("could not close redis connection: %v", err) + } + }() return b.cachedCheckAuthToken(conn, header, r) } @@ -495,8 +500,8 @@ func (b *memoryBackend) PreviewURL(previewURL string) (microsub.Timeline, error) } func (b *memoryBackend) MarkRead(channel string, uids []string) error { - timeline := b.getTimeline(channel) - err := timeline.MarkRead(uids) + tl := b.getTimeline(channel) + err := tl.MarkRead(uids) if err != nil { return err