diff --git a/index.go b/index.go index d2f08de..e01264a 100644 --- a/index.go +++ b/index.go @@ -113,6 +113,7 @@ type indexHandler struct { func (h *indexHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { sess, err := NewSession(w, r) if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) log.Printf("Error loading session: %s", err) return } @@ -125,7 +126,9 @@ func (h *indexHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { moments, err := loadMoments(h.DB, time.Now().Format("2006-01-02")) if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) log.Println(err) + return } type indexPageInfo struct { @@ -145,12 +148,14 @@ func (h *indexHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { t, err := template.New("index").Parse(IndexPageTemplate) if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) log.Println(err) return } err = t.Execute(w, indexPage) if err != nil { + http.Error(w, err.Error(), http.StatusInternalServerError) log.Println(err) return } diff --git a/indieauth.go b/indieauth.go index b4721f6..b334544 100644 --- a/indieauth.go +++ b/indieauth.go @@ -21,20 +21,20 @@ func performIndieauthCallback(state, code string, sess *Session) (bool, *indieau } func (h *IndieAuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - log.Println(r.URL.Path) - sess, err := NewSession(w, r) if err != nil { http.Error(w, err.Error(), 500) return } + defer sess.Flush() + if sess.LoggedIn { + http.Redirect(w, r, "/", 302) + return + } + if r.Method == http.MethodGet { - if sess.LoggedIn { - http.Redirect(w, r, "/", 302) - return - } if r.URL.Path == "" { fmt.Fprint(w, ` diff --git a/main.go b/main.go index 52b5962..2c687b9 100644 --- a/main.go +++ b/main.go @@ -9,6 +9,10 @@ import ( bolt "go.etcd.io/bbolt" ) +func init() { + log.SetFlags(log.LstdFlags|log.Lshortfile) +} + const DBFilename = "./moments.db" // Moment is the main information this servers remembers diff --git a/session.go b/session.go index 9f03b0e..75a5fd5 100644 --- a/session.go +++ b/session.go @@ -78,6 +78,7 @@ func getSessionCookie(w http.ResponseWriter, r *http.Request) (string, error) { Name: "session", Value: sessionVar, Expires: time.Now().Add(24 * time.Hour), + Path: "/", } http.SetCookie(w, newCookie) diff --git a/session_test.go b/session_test.go index 05faf45..cbf5bd5 100644 --- a/session_test.go +++ b/session_test.go @@ -24,6 +24,9 @@ func TestGetSessionCookieMissingCookie(t *testing.T) { if c.Value != sessionKey { t.Errorf("Wrong sessionKey %q != %q", c.Value, sessionKey) } + if c.Path != "/" { + t.Errorf("Wrong cookiepath %q != %q", c.Path, "/") + } break } } @@ -33,7 +36,7 @@ func TestGetSessionCookieCookieSet(t *testing.T) { mySessionKey := "12341234" r, _ := http.NewRequest("GET", "/", nil) - cookie := &http.Cookie{Name: "session", Value: mySessionKey} + cookie := &http.Cookie{Name: "session", Value: mySessionKey, Path: r.URL.Path} r.AddCookie(cookie) w := httptest.NewRecorder() @@ -45,4 +48,48 @@ func TestGetSessionCookieCookieSet(t *testing.T) { if mySessionKey != sessionKey { t.Errorf("getSessionKey didn't fetch sessionKey from \"session\" cookie") } + + cookies := w.Result().Cookies() + for _, c := range cookies { + if c.Name == "session" { + if c.Value != sessionKey { + t.Errorf("Wrong sessionKey %q != %q", c.Value, sessionKey) + } + if c.Path != "" { + t.Errorf("Wrong cookiepath %q != %q", c.Path, "") + } + break + } + } +} + +func TestGetSessionCookieCookieSetAuth(t *testing.T) { + mySessionKey := "12341234" + + r, _ := http.NewRequest("GET", "/auth/", nil) + cookie := &http.Cookie{Name: "session", Value: mySessionKey, Path: r.URL.Path} + r.AddCookie(cookie) + + w := httptest.NewRecorder() + sessionKey, err := getSessionCookie(w, r) + if err != nil { + t.Errorf("err != nil in getSessionCookie") + } + + if mySessionKey != sessionKey { + t.Errorf("getSessionKey didn't fetch sessionKey from \"session\" cookie") + } + + cookies := w.Result().Cookies() + for _, c := range cookies { + if c.Name == "session" { + if c.Value != sessionKey { + t.Errorf("Wrong sessionKey %q != %q", c.Value, sessionKey) + } + if c.Path != "" { + t.Errorf("Wrong cookiepath %q != %q", c.Path, "") + } + break + } + } }