diff --git a/cmd/ek/main.go b/cmd/ek/main.go index 76b17ef..cdc815c 100644 --- a/cmd/ek/main.go +++ b/cmd/ek/main.go @@ -242,7 +242,6 @@ func channelID(sub microsub.Microsub, channelNameOrID string) (string, error) { // we encountered an error, so we are not sure if it worked return channelNameOrID, err } - for _, c := range channels { if c.Name == channelNameOrID { return c.UID, nil diff --git a/cmd/eksterd/http.go b/cmd/eksterd/http.go index beeddd3..a165eff 100644 --- a/cmd/eksterd/http.go +++ b/cmd/eksterd/http.go @@ -29,9 +29,11 @@ import ( "log" "net/http" "net/url" + "os" "strings" "time" + "github.com/cristalhq/jwt/v4" "p83.nl/go/ekster/pkg/indieauth" "p83.nl/go/ekster/pkg/microsub" "p83.nl/go/ekster/pkg/util" @@ -44,7 +46,7 @@ import ( var templates embed.FS type mainHandler struct { - Backend *memoryBackend + Backend microsub.Microsub BaseURL string TemplateDir string pool *redis.Pool @@ -111,6 +113,12 @@ type authRequest struct { AccessToken string `redis:"access_token"` } +type micropubClaims struct { + jwt.RegisteredClaims + Channel string + Me string +} + func newMainHandler(backend *memoryBackend, baseURL, templateDir string, pool *redis.Pool) (*mainHandler, error) { h := &mainHandler{Backend: backend} @@ -164,11 +172,13 @@ func loadSession(sessionVar string, conn redis.Conn) (session, error) { sessionKey := "session:" + sessionVar data, err := redis.Values(conn.Do("HGETALL", sessionKey)) if err != nil { + log.Println(err) return sess, err } err = redis.ScanStruct(data, &sess) if err != nil { + log.Println(err) return sess, err } @@ -230,18 +240,19 @@ func verifyAuthCode(code, redirectURI, authEndpoint, clientID string) (bool, *au return false, nil, fmt.Errorf("unknown content-type %q while verifying authorization_code", contentType) } -func isLoggedIn(backend *memoryBackend, sess *session) bool { +func isLoggedIn(sess *session) bool { if !sess.LoggedIn { return false } - if !backend.AuthEnabled { - return true - } - - if sess.Me != backend.Me { - return false - } + // FIXME: Do we need this? + // if !backend.AuthEnabled { + // return true + // } + // + // if sess.Me != backend.Me { + // return false + // } return true } @@ -251,7 +262,6 @@ func performIndieauthCallback(clientID string, r *http.Request, sess *session) ( if state != sess.State { return false, &authResponse{}, fmt.Errorf("mismatched state") } - code := r.Form.Get("code") return verifyAuthCode(code, sess.RedirectURI, sess.AuthorizationEndpoint, clientID) } @@ -377,7 +387,7 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - if !isLoggedIn(h.Backend, &sess) { + if !isLoggedIn(&sess) { w.WriteHeader(401) fmt.Fprintf(w, "Unauthorized") return @@ -402,14 +412,14 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { for _, v := range page.Channels { if v.UID == currentChannel { page.CurrentChannel = v - if setting, e := h.Backend.Settings[v.UID]; e { - page.CurrentSetting = setting - } else { - page.CurrentSetting = channelSetting{} - } - if page.CurrentSetting.ChannelType == "" { - page.CurrentSetting.ChannelType = "postgres-stream" - } + // if setting, e := h.Backend.Settings[v.UID]; e { + // page.CurrentSetting = setting + // } else { + page.CurrentSetting = channelSetting{} + // } + // if page.CurrentSetting.ChannelType == "" { + page.CurrentSetting.ChannelType = "postgres-stream" + // } page.ExcludedTypeNames = map[string]string{ "repost": "Reposts", "like": "Likes", @@ -448,7 +458,7 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - if !isLoggedIn(h.Backend, &sess) { + if !isLoggedIn(&sess) { w.WriteHeader(401) fmt.Fprintf(w, "Unauthorized") return @@ -476,7 +486,7 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - if !isLoggedIn(h.Backend, &sess) { + if !isLoggedIn(&sess) { w.WriteHeader(401) fmt.Fprintf(w, "Unauthorized") return @@ -498,9 +508,39 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { } return } else if r.URL.Path == "/auth" { - // check if we are logged in - // TODO: if not logged in, make sure we get back here + // check arguments for auth + query := r.URL.Query() + responseType := query.Get("response_type") + if responseType != "code" { + http.Error(w, "Unsupported response_type", 400) + return + } + redirectURI := query.Get("redirect_uri") + if _, err := url.Parse(redirectURI); err != nil { + http.Error(w, "Missing redirect_uri", 400) + return + } + + clientID := query.Get("client_id") + if _, err := url.Parse(clientID); err != nil { + http.Error(w, "Missing client_id", 400) + return + } + + me := query.Get("me") + if _, err := url.Parse(me); err != nil { + http.Error(w, "Missing me", 400) + return + } + + state := query.Get("state") + scope := query.Get("scope") + if scope == "" { + scope = "create" + } + + // Check if the client is logging in sessionVar := getSessionCookie(w, r) sess, err := loadSession(sessionVar, conn) @@ -510,7 +550,7 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { return } - if !isLoggedIn(h.Backend, &sess) { + if !isLoggedIn(&sess) { sess.NextURI = r.URL.String() saveSession(sessionVar, &sess, conn) http.Redirect(w, r, "/", http.StatusFound) @@ -520,18 +560,6 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { sess.NextURI = r.URL.String() saveSession(sessionVar, &sess, conn) - query := r.URL.Query() - - // responseType := query.Get("response_type") // TODO: check response_type - me := query.Get("me") - clientID := query.Get("client_id") - redirectURI := query.Get("redirect_uri") - state := query.Get("state") - scope := query.Get("scope") - if scope == "" { - scope = "create" - } - authReq := authRequest{ Me: me, ClientID: clientID, @@ -585,6 +613,18 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { // redirect to endpoint me := r.Form.Get("url") + if !strings.HasPrefix("https://", me) { + me = fmt.Sprintf("https://%s", me) + } + meURL, err := url.Parse(me) + if err != nil { + http.Error(w, fmt.Sprintf("Not a url: %s", me), http.StatusBadRequest) + return + } + if meURL.Path == "" { + meURL.Path = "/" + } + me = meURL.String() endpoints, err := getEndpoints(me) if err != nil { @@ -695,17 +735,31 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { fmt.Fprintf(w, "ERROR: %q", err) return } - token := util.RandStringBytes(32) - _, err = conn.Do("HMSET", redis.Args{}.Add("token:"+token).AddFlat(&auth)...) + + key := []byte(os.Getenv("APP_SECRET")) + signer, err := jwt.NewSignerHS(jwt.HS256, key) if err != nil { - log.Println(err) - fmt.Fprintf(w, "ERROR: %q", err) + log.Printf("could not create signer for jwt: %s", err) + http.Error(w, err.Error(), http.StatusInternalServerError) + return + } + builder := jwt.NewBuilder(signer) + token, err := builder.Build(µpubClaims{ + RegisteredClaims: jwt.RegisteredClaims{ + IssuedAt: jwt.NewNumericDate(time.Now()), + }, + Channel: auth.Channel, + Me: auth.Me, + }) + if err != nil { + log.Printf("could not create signer for jwt: %s", err) + http.Error(w, err.Error(), http.StatusInternalServerError) return } res := authTokenResponse{ Me: auth.Me, - AccessToken: token, + AccessToken: token.String(), TokenType: "Bearer", Scope: auth.Scope, } diff --git a/cmd/eksterd/http_test.go b/cmd/eksterd/http_test.go new file mode 100644 index 0000000..810e835 --- /dev/null +++ b/cmd/eksterd/http_test.go @@ -0,0 +1,95 @@ +/* + * Ekster is a microsub server + * Copyright (c) 2021 The Ekster authors + * + * This program is free software: you can redistribute it and/or modify + * it under the terms of the GNU General Public License as published by + * the Free Software Foundation, either version 3 of the License, or + * (at your option) any later version. + * + * This program is distributed in the hope that it will be useful, + * but WITHOUT ANY WARRANTY; without even the implied warranty of + * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the + * GNU General Public License for more details. + * + * You should have received a copy of the GNU General Public License + * along with this program. If not, see . + */ + +package main + +import ( + "io/ioutil" + "math/rand" + "net/http" + "net/http/httptest" + "net/url" + "testing" + + "github.com/gomodule/redigo/redis" + "github.com/rafaeljusto/redigomock/v3" + "github.com/stretchr/testify/assert" + "p83.nl/go/ekster/pkg/server" + "p83.nl/go/ekster/pkg/util" +) + +func init() { + rand.Seed(1) +} + +func TestMainHandler_ServeHTTP_NoArgs(t *testing.T) { + conn := redigomock.NewConn() + pool := &redis.Pool{ + // Return the same connection mock for each Get() call. + Dial: func() (redis.Conn, error) { return conn, nil }, + MaxIdle: 10, + } + + h := mainHandler{Backend: &memoryBackend{AuthEnabled: true}, BaseURL: "", TemplateDir: "", pool: pool} + r := httptest.NewRequest(http.MethodGet, "/auth", nil) + w := httptest.NewRecorder() + h.ServeHTTP(w, r) + assert.Equal(t, 400, w.Code) +} +func TestMainHandler_ServeHTTP_Args(t *testing.T) { + conn := redigomock.NewConn() + pool := &redis.Pool{ + // Return the same connection mock for each Get() call. + Dial: func() (redis.Conn, error) { return conn, nil }, + MaxIdle: 10, + } + + q := url.Values{} + q.Add("response_type", "code") + q.Add("client_id", "https://example.com/") + q.Add("redirect_uri", "https://example.com/callback") + q.Add("me", "https://p83.nl/") + state := util.RandStringBytes(32) + q.Add("state", state) + q.Add("scope", "create") + h := mainHandler{Backend: &server.NullBackend{}, BaseURL: "", TemplateDir: "", pool: pool} + r := httptest.NewRequest(http.MethodGet, "/auth?"+q.Encode(), nil) + w := httptest.NewRecorder() + + conn.Command("HGETALL", "session:FpLSjFbcXoEFfRsW").ExpectMap(map[string]string{ + "logged_in": "1", + }) + conn.Command( + "HMSET", "state:XVlBzgbaiCMRAjWwhTHctcuAxhxKQFDa", + "me", "https://p83.nl/", + "client_id", "https://example.com/", + "scope", "create", + "redirect_uri", "https://example.com/callback", + "state", "XVlBzgbaiCMRAjWwhTHctcuAxhxKQFDa", + "code", "", + "channel", "", + "access_token", "", + ) + + h.ServeHTTP(w, r) + assert.Equal(t, 302, w.Code) + body, err := ioutil.ReadAll(w.Result().Body) + if assert.NoError(t, err) { + assert.Equal(t, "", string(body)) + } +} diff --git a/cmd/eksterd/micropub.go b/cmd/eksterd/micropub.go index f10776f..17bdc8b 100644 --- a/cmd/eksterd/micropub.go +++ b/cmd/eksterd/micropub.go @@ -24,6 +24,7 @@ import ( "fmt" "log" "net/http" + "os" "strings" "time" @@ -33,6 +34,8 @@ import ( "github.com/gomodule/redigo/redis" "github.com/pkg/errors" "willnorris.com/go/microformats" + + "github.com/cristalhq/jwt/v4" ) type micropubHandler struct { @@ -183,6 +186,11 @@ func getChannelFromAuthorization(r *http.Request, conn redis.Conn) (string, erro token := authHeader[7:] channel, err := redis.String(conn.Do("HGET", "token:"+token, "channel")) if err != nil { + _, err := verifyJWT(token) + if err != nil { + return "", err + } + return "", errors.Wrap(err, "could not get channel for token") } @@ -191,3 +199,28 @@ func getChannelFromAuthorization(r *http.Request, conn redis.Conn) (string, erro return "", fmt.Errorf("could not get channel from authorization") } + +func verifyJWT(token string) (bool, error) { + key := []byte(os.Getenv("APP_SECRET")) + verifier, err := jwt.NewVerifierHS(jwt.HS256, key) + if err != nil { + return false, fmt.Errorf("could not create verifier for jwt: %w", err) + } + + newToken, err := jwt.Parse([]byte(token), verifier) + if err != nil { + return false, fmt.Errorf("could not parse jwt: %w", err) + } + + var newClaims jwt.RegisteredClaims + err = json.Unmarshal(newToken.Claims(), &newClaims) + if err != nil { + return false, fmt.Errorf("could not parse jwt claims: %w", err) + } + + if !newClaims.IsValidAt(time.Now()) { + return false, fmt.Errorf("jwt is not valid now") + } + + return true, nil +} diff --git a/go.mod b/go.mod index 3aa331e..9b5f143 100644 --- a/go.mod +++ b/go.mod @@ -5,12 +5,14 @@ go 1.16 require ( github.com/axgle/mahonia v0.0.0-20180208002826-3358181d7394 github.com/blevesearch/bleve/v2 v2.0.3 + github.com/cristalhq/jwt/v4 v4.0.0-beta // indirect github.com/davecgh/go-spew v1.1.1 // indirect github.com/gilliek/go-opml v1.0.0 github.com/golang-migrate/migrate/v4 v4.15.1 - github.com/gomodule/redigo v1.8.2 + github.com/gomodule/redigo v1.8.5 github.com/lib/pq v1.10.1 github.com/pkg/errors v0.9.1 + github.com/rafaeljusto/redigomock/v3 v3.0.1 // indirect github.com/stretchr/testify v1.7.0 golang.org/x/net v0.0.0-20211013171255-e13a2654a71e willnorris.com/go/microformats v1.1.0 diff --git a/go.sum b/go.sum index 07b746f..d95e2b6 100644 --- a/go.sum +++ b/go.sum @@ -312,6 +312,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:ma github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU= github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY= github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E= +github.com/cristalhq/jwt/v4 v4.0.0-beta h1:TSNUSW7CCvMtX+Rg3vj706BwqIEt9HMvVv6Syifq5mU= +github.com/cristalhq/jwt/v4 v4.0.0-beta/go.mod h1:HnYraSNKDRag1DZP92rYHyrjyQHnVEHPNqesmzs+miQ= github.com/cyphar/filepath-securejoin v0.2.2/go.mod h1:FpkQEhXnPnOthhzymB7CGsFk2G9VLXONKD9G7QGMM+4= github.com/cznic/mathutil v0.0.0-20180504122225-ca4c9f2c1369/go.mod h1:e6NPNENfs9mPDVNRekM7lKScauxd5kXTr1Mfyig6TDM= github.com/d2g/dhcp4 v0.0.0-20170904100407-a1d1b6c41b1c/go.mod h1:Ct2BUK8SB0YC1SMSibvLzxjeJLnrYEVLULFNiHY9YfQ= @@ -496,6 +498,9 @@ github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM= github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q= github.com/gomodule/redigo v1.8.2 h1:H5XSIre1MB5NbPYFp+i1NBbb5qN1W8Y8YAQoAYbkm8k= github.com/gomodule/redigo v1.8.2/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUzlpkBYO0= +github.com/gomodule/redigo v1.8.3/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUzlpkBYO0= +github.com/gomodule/redigo v1.8.5 h1:nRAxCa+SVsyjSBrtZmG/cqb6VbTmuRzpg/PoTFlpumc= +github.com/gomodule/redigo v1.8.5/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUzlpkBYO0= github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ= github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8= @@ -838,6 +843,8 @@ github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4O github.com/prometheus/procfs v0.2.0/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU= github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA= github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU= +github.com/rafaeljusto/redigomock/v3 v3.0.1 h1:AUsXTuf+UEMwVEgRHRDYFFCJ1quS2JVDQmTWypjI5mI= +github.com/rafaeljusto/redigomock/v3 v3.0.1/go.mod h1:51LNR7Q4YFsi0N+CHr7+FC1Jx2lPLzcRHCPlLO2Qbpw= github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4= github.com/remyoudompheng/bigfft v0.0.0-20190728182440-6a916e37a237/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo= diff --git a/pkg/microsub/protocol.go b/pkg/microsub/protocol.go index df5f69e..7915b76 100644 --- a/pkg/microsub/protocol.go +++ b/pkg/microsub/protocol.go @@ -155,6 +155,7 @@ type Microsub interface { ItemSearch(channel, query string) ([]Item, error) Events() (chan sse.Message, error) + RefreshFeeds() } // MarshalJSON encodes an Unread value as JSON diff --git a/pkg/server/null.go b/pkg/server/null.go index a1aee92..aedf301 100644 --- a/pkg/server/null.go +++ b/pkg/server/null.go @@ -27,6 +27,12 @@ import ( type NullBackend struct { } +// RefreshFeeds refreshes feeds +func (b *NullBackend) RefreshFeeds() { + // TODO implement me + // panic("implement me") +} + // ChannelsGetList gets no channels func (b *NullBackend) ChannelsGetList() ([]microsub.Channel, error) { return []microsub.Channel{