Compare commits

...

4 Commits

Author SHA1 Message Date
9b306c71a9 Improve session tests 2018-10-02 22:24:18 +02:00
048c474cf6 Make index template const 2018-10-02 22:12:16 +02:00
d36b79e094 Add test for empty Moment slice 2018-10-02 22:07:55 +02:00
c9925a4320 Extract Reverse func 2018-10-02 22:06:17 +02:00
7 changed files with 112 additions and 16 deletions

View File

@ -8,7 +8,7 @@ import (
"time" "time"
) )
var indexPageTemplate = `<html> const IndexPageTemplate = `<html>
<head> <head>
<meta charset="UTF-8"> <meta charset="UTF-8">
<meta name="viewport" <meta name="viewport"
@ -106,7 +106,6 @@ var indexPageTemplate = `<html>
</html> </html>
` `
type indexHandler struct { type indexHandler struct {
DB *bbolt.DB DB *bbolt.DB
} }
@ -114,6 +113,7 @@ type indexHandler struct {
func (h *indexHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *indexHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
sess, err := NewSession(w, r) sess, err := NewSession(w, r)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
log.Printf("Error loading session: %s", err) log.Printf("Error loading session: %s", err)
return return
} }
@ -126,7 +126,9 @@ func (h *indexHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
moments, err := loadMoments(h.DB, time.Now().Format("2006-01-02")) moments, err := loadMoments(h.DB, time.Now().Format("2006-01-02"))
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
log.Println(err) log.Println(err)
return
} }
type indexPageInfo struct { type indexPageInfo struct {
@ -139,23 +141,21 @@ func (h *indexHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
indexPage.Me = sess.Me indexPage.Me = sess.Me
if len(moments) > 0 { if len(moments) > 0 {
a := moments Reverse(moments)
for i := len(a)/2 - 1; i >= 0; i-- {
opp := len(a) - 1 - i
a[i], a[opp] = a[opp], a[i]
}
lastMoment := moments[0] lastMoment := moments[0]
indexPage.LastMomentSeconds = lastMoment.Time.Unix() indexPage.LastMomentSeconds = lastMoment.Time.Unix()
} }
t, err := template.New("index").Parse(indexPageTemplate) t, err := template.New("index").Parse(IndexPageTemplate)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
log.Println(err) log.Println(err)
return return
} }
err = t.Execute(w, indexPage) err = t.Execute(w, indexPage)
if err != nil { if err != nil {
http.Error(w, err.Error(), http.StatusInternalServerError)
log.Println(err) log.Println(err)
return return
} }

View File

@ -21,20 +21,20 @@ func performIndieauthCallback(state, code string, sess *Session) (bool, *indieau
} }
func (h *IndieAuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func (h *IndieAuthHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
log.Println(r.URL.Path)
sess, err := NewSession(w, r) sess, err := NewSession(w, r)
if err != nil { if err != nil {
http.Error(w, err.Error(), 500) http.Error(w, err.Error(), 500)
return return
} }
defer sess.Flush() defer sess.Flush()
if sess.LoggedIn {
http.Redirect(w, r, "/", 302)
return
}
if r.Method == http.MethodGet { if r.Method == http.MethodGet {
if sess.LoggedIn {
http.Redirect(w, r, "/", 302)
return
}
if r.URL.Path == "" { if r.URL.Path == "" {
fmt.Fprint(w, `<!doctype html> fmt.Fprint(w, `<!doctype html>
<html> <html>

View File

@ -9,6 +9,10 @@ import (
bolt "go.etcd.io/bbolt" bolt "go.etcd.io/bbolt"
) )
func init() {
log.SetFlags(log.LstdFlags|log.Lshortfile)
}
const DBFilename = "./moments.db" const DBFilename = "./moments.db"
// Moment is the main information this servers remembers // Moment is the main information this servers remembers

View File

@ -78,6 +78,7 @@ func getSessionCookie(w http.ResponseWriter, r *http.Request) (string, error) {
Name: "session", Name: "session",
Value: sessionVar, Value: sessionVar,
Expires: time.Now().Add(24 * time.Hour), Expires: time.Now().Add(24 * time.Hour),
Path: "/",
} }
http.SetCookie(w, newCookie) http.SetCookie(w, newCookie)

View File

@ -24,6 +24,9 @@ func TestGetSessionCookieMissingCookie(t *testing.T) {
if c.Value != sessionKey { if c.Value != sessionKey {
t.Errorf("Wrong sessionKey %q != %q", c.Value, sessionKey) t.Errorf("Wrong sessionKey %q != %q", c.Value, sessionKey)
} }
if c.Path != "/" {
t.Errorf("Wrong cookiepath %q != %q", c.Path, "/")
}
break break
} }
} }
@ -33,7 +36,7 @@ func TestGetSessionCookieCookieSet(t *testing.T) {
mySessionKey := "12341234" mySessionKey := "12341234"
r, _ := http.NewRequest("GET", "/", nil) r, _ := http.NewRequest("GET", "/", nil)
cookie := &http.Cookie{Name: "session", Value: mySessionKey} cookie := &http.Cookie{Name: "session", Value: mySessionKey, Path: r.URL.Path}
r.AddCookie(cookie) r.AddCookie(cookie)
w := httptest.NewRecorder() w := httptest.NewRecorder()
@ -45,4 +48,48 @@ func TestGetSessionCookieCookieSet(t *testing.T) {
if mySessionKey != sessionKey { if mySessionKey != sessionKey {
t.Errorf("getSessionKey didn't fetch sessionKey from \"session\" cookie") t.Errorf("getSessionKey didn't fetch sessionKey from \"session\" cookie")
} }
cookies := w.Result().Cookies()
for _, c := range cookies {
if c.Name == "session" {
if c.Value != sessionKey {
t.Errorf("Wrong sessionKey %q != %q", c.Value, sessionKey)
}
if c.Path != "" {
t.Errorf("Wrong cookiepath %q != %q", c.Path, "")
}
break
}
}
}
func TestGetSessionCookieCookieSetAuth(t *testing.T) {
mySessionKey := "12341234"
r, _ := http.NewRequest("GET", "/auth/", nil)
cookie := &http.Cookie{Name: "session", Value: mySessionKey, Path: r.URL.Path}
r.AddCookie(cookie)
w := httptest.NewRecorder()
sessionKey, err := getSessionCookie(w, r)
if err != nil {
t.Errorf("err != nil in getSessionCookie")
}
if mySessionKey != sessionKey {
t.Errorf("getSessionKey didn't fetch sessionKey from \"session\" cookie")
}
cookies := w.Result().Cookies()
for _, c := range cookies {
if c.Name == "session" {
if c.Value != sessionKey {
t.Errorf("Wrong sessionKey %q != %q", c.Value, sessionKey)
}
if c.Path != "" {
t.Errorf("Wrong cookiepath %q != %q", c.Path, "")
}
break
}
}
} }

11
util.go
View File

@ -12,4 +12,13 @@ func RandStringBytes(n int) string {
b[i] = letterBytes[rand.Intn(len(letterBytes))] b[i] = letterBytes[rand.Intn(len(letterBytes))]
} }
return string(b) return string(b)
} }
func Reverse(moments []Moment) {
a := moments
for i := len(a)/2 - 1; i >= 0; i-- {
opp := len(a) - 1 - i
a[i], a[opp] = a[opp], a[i]
}
}

35
util_test.go Normal file
View File

@ -0,0 +1,35 @@
package main
import (
"testing"
"time"
)
func TestReverseEmpty(t *testing.T) {
var moments []Moment
Reverse(moments)
}
func TestReverse(t *testing.T) {
moments := []Moment{
{Key: "2018-01-01", Memo: "test", Time: time.Now()},
{Key: "2018-01-02", Memo: "test2", Time: time.Now()},
{Key: "2018-01-03", Memo: "test3", Time: time.Now()},
{Key: "2018-01-04", Memo: "test4", Time: time.Now()},
}
Reverse(moments)
if moments[0].Key != "2018-01-04" {
t.Errorf("wrong 1st key %q != %q", moments[0].Key, "2018-01-04")
}
if moments[1].Key != "2018-01-03" {
t.Errorf("wrong 2nd key %q != %q", moments[1].Key, "2018-01-03")
}
if moments[2].Key != "2018-01-02" {
t.Errorf("wrong 3rd key %q != %q", moments[2].Key, "2018-01-02")
}
if moments[3].Key != "2018-01-01" {
t.Errorf("wrong 4th key %q != %q", moments[3].Key, "2018-01-01")
}
}