Extract session code

This commit is contained in:
Peter Stuifzand 2018-07-12 21:00:47 +02:00
parent 0f9752452d
commit 4c59931283

View File

@ -90,19 +90,7 @@ func newMainHandler(backend *memoryBackend) (*mainHandler, error) {
return h, nil return h, nil
} }
func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { func getSessionCookie(w http.ResponseWriter, r *http.Request) string {
conn := pool.Get()
defer conn.Close()
err := r.ParseForm()
if err != nil {
log.Println(err)
http.Error(w, fmt.Sprintf("Bad Request: %s", err.Error()), 400)
return
}
if r.Method == http.MethodGet {
if r.URL.Path == "/" {
c, err := r.Cookie("session") c, err := r.Cookie("session")
sessionVar := util.RandStringBytes(16) sessionVar := util.RandStringBytes(16)
@ -118,14 +106,43 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
sessionVar = c.Value sessionVar = c.Value
} }
return sessionVar
}
func loadSession(sessionVar string, conn redis.Conn) (session, error) {
var sess session var sess session
sessionKey := "session:" + sessionVar sessionKey := "session:" + sessionVar
data, err := redis.Values(conn.Do("HGETALL", sessionKey)) data, err := redis.Values(conn.Do("HGETALL", sessionKey))
if err != nil { if err != nil {
fmt.Fprintf(w, "ERROR: %q\n", err) return sess, err
return
} }
err = redis.ScanStruct(data, &sess) err = redis.ScanStruct(data, &sess)
if err != nil {
return sess, err
}
return sess, nil
}
func saveSession(sessionVar string, sess *session, conn redis.Conn) error {
_, err := conn.Do("HMSET", redis.Args{}.Add("session:"+sessionVar).AddFlat(sess)...)
return err
}
func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
conn := pool.Get()
defer conn.Close()
err := r.ParseForm()
if err != nil {
log.Println(err)
http.Error(w, fmt.Sprintf("Bad Request: %s", err.Error()), 400)
return
}
if r.Method == http.MethodGet {
if r.URL.Path == "/" {
sessionVar := getSessionCookie(w, r)
sess, err := loadSession(sessionVar, conn)
if err != nil { if err != nil {
fmt.Fprintf(w, "ERROR: %q\n", err) fmt.Fprintf(w, "ERROR: %q\n", err)
return return
@ -146,19 +163,9 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
http.Redirect(w, r, "/", 302) http.Redirect(w, r, "/", 302)
return return
} }
sessionVar := c.Value sessionVar := c.Value
var sess session sess, err := loadSession(sessionVar, conn)
sessionKey := "session:" + sessionVar
data, err := redis.Values(conn.Do("HGETALL", sessionKey))
if err != nil {
fmt.Fprintf(w, "ERROR: %q\n", err)
return
}
err = redis.ScanStruct(data, &sess)
if err != nil {
fmt.Fprintf(w, "ERROR: %q\n", err)
return
}
state := r.Form.Get("state") state := r.Form.Get("state")
if state != sess.State { if state != sess.State {
@ -205,7 +212,7 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
sess.Me = authResponse.Me sess.Me = authResponse.Me
sess.LoggedIn = true sess.LoggedIn = true
conn.Do("HMSET", redis.Args{}.Add(sessionKey).AddFlat(sess)...) saveSession(sessionVar, &sess, conn)
http.Redirect(w, r, "/", 302) http.Redirect(w, r, "/", 302)
return return
} else { } else {
@ -219,18 +226,7 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
return return
} }
sessionVar := c.Value sessionVar := c.Value
var sess session sess, err := loadSession(sessionVar, conn)
sessionKey := "session:" + sessionVar
data, err := redis.Values(conn.Do("HGETALL", sessionKey))
if err != nil {
fmt.Fprintf(w, "ERROR: %q\n", err)
return
}
err = redis.ScanStruct(data, &sess)
if err != nil {
fmt.Fprintf(w, "ERROR: %q\n", err)
return
}
var page settingsPage var page settingsPage
page.Session = sess page.Session = sess
@ -286,7 +282,7 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
ClientID: clientID, ClientID: clientID,
LoggedIn: false, LoggedIn: false,
} }
conn.Do("HMSET", redis.Args{}.Add("session:"+sessionVar).AddFlat(&sess)...) saveSession(sessionVar, &sess, conn)
q := authURL.Query() q := authURL.Query()
q.Add("response_type", "id") q.Add("response_type", "id")