From ec510484d845b84380068f8cecc668568c1a7653 Mon Sep 17 00:00:00 2001 From: Peter Stuifzand Date: Mon, 18 Mar 2019 21:47:49 +0100 Subject: [PATCH] Cleanup websub hub (and use postgres) --- cmd/hubserver/handler.go | 239 +++++++++++++++++++ cmd/hubserver/main.go | 380 +----------------------------- cmd/hubserver/storage/postgres.go | 62 +++++ cmd/hubserver/storage/storage.go | 21 ++ db/01_tables.sql | 9 + 5 files changed, 342 insertions(+), 369 deletions(-) create mode 100644 cmd/hubserver/handler.go create mode 100644 cmd/hubserver/storage/postgres.go create mode 100644 cmd/hubserver/storage/storage.go create mode 100644 db/01_tables.sql diff --git a/cmd/hubserver/handler.go b/cmd/hubserver/handler.go new file mode 100644 index 0000000..e331927 --- /dev/null +++ b/cmd/hubserver/handler.go @@ -0,0 +1,239 @@ +package main + +import ( + "crypto/hmac" + "crypto/sha1" + "fmt" + "io/ioutil" + "log" + "net/http" + "net/url" + "os" + "strconv" + "strings" + "time" + + "p83.nl/go/websub-hub/cmd/hubserver/storage" +) + +type subscriptionHandler struct { + store storage.Service + baseURL string +} + +func (handler *subscriptionHandler) handlePublish(w http.ResponseWriter, r *http.Request) error { + topic := r.Form.Get("hub.topic") + log.Printf("publish: topic = %s\n", topic) + + client := &http.Client{} + req, err := http.NewRequest("GET", topic, nil) + if err != nil { + return err + } + req.Header.Add("Accept", "*/*") + res, err := client.Do(req) + if err != nil { + return err + } + defer res.Body.Close() + + feedContentType := res.Header.Get("Content-Type") + feedContent, err := ioutil.ReadAll(res.Body) + if err != nil { + return err + } + + if subs, err := handler.store.Subscribers(topic); err != nil { + for _, sub := range subs { + log.Printf("publish: creating post to %s\n", sub.Callback) + postReq, err := http.NewRequest("POST", sub.Callback, strings.NewReader(string(feedContent))) + if err != nil { + log.Printf("While creating request to %s: %s", sub.Callback, err) + continue + } + postReq.Header.Add("Content-Type", feedContentType) + postReq.Header.Add("Link", + fmt.Sprintf( + "<%s>; rel=hub, <%s>; rel=self", + handler.baseURL, + topic, + )) + if sub.Secret != "" { + mac := hmac.New(sha1.New, []byte(sub.Secret)) + mac.Write(feedContent) + signature := mac.Sum(nil) + postReq.Header.Add("X-Hub-Signature", fmt.Sprintf("sha1=%x", signature)) + } + postRes, err := client.Do(postReq) + if err != nil { + log.Printf("While POSTing to %s: %s", sub.Callback, err) + continue + } + log.Printf("publish: post send to %s\n", sub.Callback) + log.Println("Response:") + _ = postRes.Write(os.Stdout) + } + } + + return nil +} + +func (handler *subscriptionHandler) handleUnsubscription(w http.ResponseWriter, r *http.Request) error { + callback := r.Form.Get("hub.callback") + topic := r.Form.Get("hub.topic") + mode := r.Form.Get("hub.mode") + + if subscribers, err := handler.store.Subscribers(topic); err != nil { + for _, subscriber := range subscribers { + if subscriber.Callback != callback { + continue + } + ourChallenge := randStringBytes(12) + + validationURL, err := url.Parse(callback) + if err != nil { + log.Println(err) + return err + } + q := validationURL.Query() + q.Add("hub.mode", mode) + q.Add("hub.topic", topic) + q.Add("hub.challenge", ourChallenge) + validationURL.RawQuery = q.Encode() + if validateURL(validationURL.String(), ourChallenge) { + err = handler.store.Unsubscribe(topic, callback) + if err != nil { + return err + } + } + } + + w.WriteHeader(200) + _, err = fmt.Fprintf(w, "Unsubscribed\n") + return err + } else { + http.Error(w, "Hub does not handle subscription for topic", 400) + } + return nil +} + +func (handler *subscriptionHandler) handleSubscription(w http.ResponseWriter, r *http.Request) error { + log.Printf("subscription request received: %s %#v\n", r.URL.String(), r.Form) + + callback := r.Form.Get("hub.callback") + topic := r.Form.Get("hub.topic") + secret := r.Form.Get("hub.secret") + leaseSecondsStr := r.Form.Get("hub.lease_seconds") + leaseSeconds, err := strconv.ParseInt(leaseSecondsStr, 10, 64) + if leaseSecondsStr != "" && err != nil { + http.Error(w, "hub.lease_seconds is used, but not a valid integer", 400) + log.Printf("hub.lease_seconds is used, but not a valid integer (%s)\n", leaseSecondsStr) + return err + } + + log.Printf("subscribe: received for topic=%s to callback=%s (lease=%ds)\n", topic, callback, leaseSeconds) + + if _, e := r.Form["hub.lease_seconds"]; !e { + leaseSeconds = 3600 + leaseSecondsStr = "3600" + log.Printf("subscribe: lease_seconds was empty use default %ds\n", leaseSeconds) + } + + callbackURL, err := url.Parse(callback) + if callback == "" || err != nil { + http.Error(w, "Can not parse callback url", 400) + log.Printf("Can not parse callback url: %s\n", callback) + return err + } + + topicURL, err := url.Parse(topic) + if topic == "" || err != nil { + http.Error(w, "Can't parse topic url", 400) + log.Printf("Can't parse topic url: %s\n", topic) + return err + } + + log.Println("subscribe: sending 202 header request accepted") + w.WriteHeader(202) + _, _ = fmt.Fprint(w, "Accepted\r\n") + + go func() { + ourChallenge := randStringBytes(12) + + validationURL := *callbackURL + q := validationURL.Query() + q.Add("hub.mode", "subscribe") + q.Add("hub.topic", topicURL.String()) + q.Add("hub.challenge", ourChallenge) + q.Add("hub.lease_seconds", leaseSecondsStr) + if secret != "" { + q.Add("hub.verify_token", secret) + } + validationURL.RawQuery = q.Encode() + + log.Printf("subscribe: async validation with url %s\n", validationURL.String()) + + if validateURL(validationURL.String(), ourChallenge) { + log.Printf("subscribe: validation valid\n") + _ = handler.store.Subscribe(topicURL.String(), storage.Subscriber{callbackURL.String(), leaseSeconds, secret, time.Now()}) + } else { + log.Printf("subscribe: validation failed\n") + } + }() + + return nil +} + +func validateURL(url, challenge string) bool { + client := http.Client{} + + req, err := http.NewRequest(http.MethodGet, url, nil) + if err != nil { + log.Println(err) + return false + } + res, err := client.Do(req) + if err != nil { + log.Println(err) + return false + } + defer res.Body.Close() + + body, err := ioutil.ReadAll(res.Body) + if err != nil { + log.Println(err) + return false + } + + return strings.Contains(string(body), challenge) +} + +func (handler *subscriptionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { + if r.Header.Get("Content-Type") != "application/x-www-form-urlencoded" { + http.Error(w, "Bad Request", 400) + return + } + + err := r.ParseForm() + if err != nil { + http.Error(w, "Bad Request", 400) + return + } + + mode := r.Form.Get("hub.mode") + + if mode == "subscribe" { + err = handler.handleSubscription(w, r) + return + } else if mode == "unsubscribe" { + err = handler.handleUnsubscription(w, r) + return + } else if mode == "publish" { + log.Println("hub.mode=publish received") + err = handler.handlePublish(w, r) + return + } else { + http.Error(w, "Unknown hub.mode", 400) + return + } +} diff --git a/cmd/hubserver/main.go b/cmd/hubserver/main.go index 9f72e93..0f0e58c 100644 --- a/cmd/hubserver/main.go +++ b/cmd/hubserver/main.go @@ -1,21 +1,14 @@ package main import ( - "crypto/hmac" - "crypto/sha1" - "encoding/json" "flag" "fmt" - "io/ioutil" "log" "math/rand" "net/http" - "net/url" - "os" - "strconv" - "strings" - "sync" "time" + + "p83.nl/go/websub-hub/cmd/hubserver/storage" ) func init() { @@ -32,376 +25,25 @@ func randStringBytes(n int) string { return string(b) } -type Subscriber struct { - Callback string - LeaseSeconds int64 - Secret string - Created time.Time -} - type Stat struct { Updates int LastUpdate time.Time } -type subscriptionHandler struct { - LockSubs sync.Mutex - Subscribers map[string][]Subscriber - - LockStats sync.Mutex - Stats map[string]Stat -} - -func (handler *subscriptionHandler) handlePublish(w http.ResponseWriter, r *http.Request) error { - topic := r.Form.Get("hub.topic") - log.Printf("publish: topic = %s\n", topic) - - client := &http.Client{} - req, err := http.NewRequest("GET", topic, nil) - req.Header.Add("Accept", "*/*") - res, err := client.Do(req) - if err != nil { - return err - } - defer res.Body.Close() - - feedContentType := res.Header.Get("Content-Type") - feedContent, err := ioutil.ReadAll(res.Body) - if err != nil { - return err - } - - handler.incStat("published") - - if subs, e := handler.Subscribers[topic]; e { - for _, sub := range subs { - - handler.incStat("publish.post") - log.Printf("publish: creating post to %s\n", sub.Callback) - postReq, err := http.NewRequest("POST", sub.Callback, strings.NewReader(string(feedContent))) - if err != nil { - log.Printf("While creating request to %s: %s", sub.Callback, err) - continue - } - postReq.Header.Add("Content-Type", feedContentType) - postReq.Header.Add("Link", - fmt.Sprintf( - "<%s>; rel=hub, <%s>; rel=self", - "https://hub.stuifzandapp.com/", - topic, - )) - if sub.Secret != "" { - mac := hmac.New(sha1.New, []byte(sub.Secret)) - mac.Write(feedContent) - signature := mac.Sum(nil) - postReq.Header.Add("X-Hub-Signature", fmt.Sprintf("sha1=%x", signature)) - } - postRes, err := client.Do(postReq) - if err != nil { - log.Printf("While POSTing to %s: %s", sub.Callback, err) - continue - } - log.Printf("publish: post send to %s\n", sub.Callback) - log.Println("Response:") - postRes.Write(os.Stdout) - - } - } else { - log.Println("Topic not found") - } - - return nil -} - -func (handler *subscriptionHandler) handleUnsubscription(w http.ResponseWriter, r *http.Request) error { - log.Println(r.Form.Encode()) - callback := r.Form.Get("hub.callback") - topic := r.Form.Get("hub.topic") - mode := r.Form.Get("hub.mode") - - if subs, e := handler.Subscribers[topic]; e { - for i, sub := range subs { - if sub.Callback != callback { - continue - } - ourChallenge := randStringBytes(12) - - validationURL, err := url.Parse(callback) - if err != nil { - log.Println(err) - return err - } - q := validationURL.Query() - q.Add("hub.mode", mode) - q.Add("hub.topic", topic) - q.Add("hub.challenge", ourChallenge) - validationURL.RawQuery = q.Encode() - if validateURL(validationURL.String(), ourChallenge) { - subs = append(subs[:i], subs[i+1:]...) - log.Println(handler.save()) - break - } - } - w.WriteHeader(200) - fmt.Fprintf(w, "Unsubscribed\n") - } else { - http.Error(w, "Hub does not handle subscription for topic", 400) - } - return nil -} - -func (handler *subscriptionHandler) handleSubscription(w http.ResponseWriter, r *http.Request) error { - log.Printf("subscription request received: %s %#v\n", r.URL.String(), r.Form) - - callback := r.Form.Get("hub.callback") - topic := r.Form.Get("hub.topic") - secret := r.Form.Get("hub.secret") - leaseSecondsStr := r.Form.Get("hub.lease_seconds") - leaseSeconds, err := strconv.ParseInt(leaseSecondsStr, 10, 64) - if leaseSecondsStr != "" && err != nil { - http.Error(w, "hub.lease_seconds is used, but not a valid integer", 400) - log.Printf("hub.lease_seconds is used, but not a valid integer (%s)\n", leaseSecondsStr) - return err - } - - log.Printf("subscribe: received for topic=%s to callback=%s (lease=%ds)\n", topic, callback, leaseSeconds) - - if _, e := r.Form["hub.lease_seconds"]; !e { - leaseSeconds = 3600 - leaseSecondsStr = "3600" - log.Printf("subscribe: lease_seconds was empty use default %ds\n", leaseSeconds) - } - - callbackURL, err := url.Parse(callback) - if callback == "" || err != nil { - http.Error(w, "Can't parse callback url", 400) - log.Printf("Can't parse callback url: %s\n", callback) - return err - } - - topicURL, err := url.Parse(topic) - if topic == "" || err != nil { - http.Error(w, "Can't parse topic url", 400) - log.Printf("Can't parse topic url: %s\n", topic) - return err - } - - log.Println("subscribe: sending 202 header request accepted") - w.WriteHeader(202) - fmt.Fprint(w, "Accepted") - - go func() { - ourChallenge := randStringBytes(12) - - validationURL := *callbackURL - q := validationURL.Query() - q.Add("hub.mode", "subscribe") - q.Add("hub.topic", topicURL.String()) - q.Add("hub.challenge", ourChallenge) - q.Add("hub.lease_seconds", leaseSecondsStr) - if secret != "" { - q.Add("hub.verify_token", secret) - } - validationURL.RawQuery = q.Encode() - - log.Printf("subscribe: async validation with url %s\n", validationURL.String()) - - if validateURL(validationURL.String(), ourChallenge) { - log.Printf("subscribe: validation valid\n") - handler.addSubscriberCallback(topicURL.String(), Subscriber{callbackURL.String(), leaseSeconds, secret, time.Now()}) - } else { - log.Printf("subscribe: validation failed\n") - } - }() - - return nil -} - -func validateURL(url, challenge string) bool { - client := http.Client{} - - req, err := http.NewRequest(http.MethodGet, url, nil) - if err != nil { - log.Println(err) - return false - } - res, err := client.Do(req) - if err != nil { - log.Println(err) - return false - } - defer res.Body.Close() - - body, err := ioutil.ReadAll(res.Body) - if err != nil { - log.Println(err) - return false - } - - return strings.Contains(string(body), challenge) -} - -func (handler *subscriptionHandler) addSubscriberCallback(topic string, subscriber Subscriber) { - if subs, e := handler.Subscribers[topic]; e { - for i, sub := range subs { - if sub.Callback == subscriber.Callback { - handler.Subscribers[topic][i] = subscriber - if err := handler.save(); err != nil { - log.Println(err) - } - return - } - } - } - - // not found create a new subscription - handler.Subscribers[topic] = append(handler.Subscribers[topic], subscriber) -} - -func (handler *subscriptionHandler) incStat(name string) { - if v, e := handler.Stats[name]; e { - handler.Stats[name] = Stat{LastUpdate: time.Now(), Updates: v.Updates + 1} - } else { - handler.Stats[name] = Stat{LastUpdate: time.Now(), Updates: 1} - } - - handler.saveStats() -} - -func (handler *subscriptionHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) { - if r.Method == http.MethodGet { - _, _ = fmt.Fprintln(w, "WebSub hub") - if r.URL.Query().Get("debug") == "1" { - handler.incStat("http.index.debug") - enc := json.NewEncoder(w) - enc.SetIndent("", " ") - _ = enc.Encode(handler.Subscribers) - _ = enc.Encode(handler.Stats) - } - return - } - - if r.Header.Get("Content-Type") != "application/x-www-form-urlencoded" { - http.Error(w, "Bad Request", 400) - return - } - - err := r.ParseForm() - if err != nil { - http.Error(w, "Bad Request", 400) - return - } - - mode := r.Form.Get("hub.mode") - - if mode == "subscribe" { - err = handler.handleSubscription(w, r) - return - } else if mode == "unsubscribe" { - err = handler.handleUnsubscription(w, r) - return - } else if mode == "publish" { - log.Println("hub.mode=publish received") - err = handler.handlePublish(w, r) - return - } else { - http.Error(w, "Unknown hub.mode", 400) - return - } -} - -func (handler *subscriptionHandler) loadStats() error { - handler.LockStats.Lock() - defer handler.LockStats.Unlock() - - file, err := os.Open("./stats.json") - if err != nil { - if os.IsNotExist(err) { - handler.Stats = make(map[string]Stat) - return nil - } - } - defer file.Close() - dec := json.NewDecoder(file) - err = dec.Decode(&handler.Stats) - return err -} - -func (handler *subscriptionHandler) saveStats() error { - handler.LockStats.Lock() - defer handler.LockStats.Unlock() - - file, err := os.Create("./stats.json") - if err != nil { - return err - } - defer file.Close() - dec := json.NewEncoder(file) - dec.SetIndent("", " ") - - err = dec.Encode(&handler.Stats) - return err -} - -func (handler *subscriptionHandler) loadSubscriptions() error { - handler.LockSubs.Lock() - defer handler.LockSubs.Unlock() - - file, err := os.Open("./subscription.json") - if err != nil { - if os.IsNotExist(err) { - handler.Subscribers = make(map[string][]Subscriber) - return nil - } - } - defer file.Close() - dec := json.NewDecoder(file) - err = dec.Decode(&handler.Subscribers) - return err -} - - -func (handler *subscriptionHandler) saveSubscriptions() error { - handler.LockSubs.Lock() - defer handler.LockSubs.Unlock() - - file, err := os.Create("./subscription.json") - if err != nil { - return err - } - defer file.Close() - dec := json.NewEncoder(file) - dec.SetIndent("", " ") - err = dec.Encode(&handler.Subscribers) - return err -} - -func (handler *subscriptionHandler) load() error { - err := handler.loadSubscriptions() - if err != nil { - return err - } - return handler.loadStats() -} - -func (handler *subscriptionHandler) save() error { - handler.saveSubscriptions() - return handler.saveStats() -} - func main() { - var hostPort string + var hostPort,baseURL string flag.StringVar(&hostPort, "http", ":80", "host and port to listen on") + flag.StringVar(&baseURL, "baseurl", "", "baseurl that the server should response with") flag.Parse() - handler := &subscriptionHandler{} - - log.Println(handler.load()) - - log.Printf("%#v\n", handler.Subscribers) - log.Printf("%#v\n", handler.Stats) + dsn := fmt.Sprintf("postgres://%v:%v@localhost:9999/hub?sslmode=disable", "postgres", "simple") + store, err := storage.New(dsn) + if err != nil { + log.Fatal(err) + } + defer store.Close() + handler := &subscriptionHandler{store, baseURL} http.Handle("/", handler) - log.Fatal(http.ListenAndServe(hostPort, nil)) } diff --git a/cmd/hubserver/storage/postgres.go b/cmd/hubserver/storage/postgres.go new file mode 100644 index 0000000..6adb3f4 --- /dev/null +++ b/cmd/hubserver/storage/postgres.go @@ -0,0 +1,62 @@ +package storage + +import ( + "database/sql" + "fmt" + + _ "github.com/lib/pq" +) + +type postgres struct { + db *sql.DB +} + +func New(dsn string) (Service, error) { + pool, err := sql.Open("postgres", dsn) + if err != nil { + return nil, fmt.Errorf("could not open database connection: %v", err) + } + return &postgres{pool}, nil +} + +func (s *postgres) Close() error { + return s.db.Close() +} + +func (s *postgres) Subscribe(topic string, sub Subscriber) error { + _, err := s.db.Exec( + `INSERT INTO "subscribers" ("topic", "callback", "lease_seconds", "secret", "created") VALUES (?, ?, ?, ?, now())`, + topic, + sub.Callback, + sub.LeaseSeconds, + sub.Secret, + ) + return err +} + +func (s *postgres) Unsubscribe(topic, callback string) error { + _, err := s.db.Exec( + `DELETE FROM "subscribers" WHERE "topic" = ? AND "callback" = ?`, + topic, + callback, + ) + return err +} + +func (s *postgres) Subscribers(topic string) ([]Subscriber, error) { + rows, err := s.db.Query(`SELECT callback, lease_seconds, secret, created FROM "subscribers"`) + if err != nil { + return nil, err + } + + var subscribers []Subscriber + for rows.Next() { + var sub Subscriber + err := rows.Scan(&sub.Callback, &sub.LeaseSeconds, &sub.Secret, &sub.Created) + if err != nil { + return nil, err + } + subscribers = append(subscribers, sub) + } + return subscribers, nil +} diff --git a/cmd/hubserver/storage/storage.go b/cmd/hubserver/storage/storage.go new file mode 100644 index 0000000..e381d86 --- /dev/null +++ b/cmd/hubserver/storage/storage.go @@ -0,0 +1,21 @@ +package storage + +import ( + "time" + +) + +type Service interface { + Subscribe(topic string, subscriber Subscriber) error + Unsubscribe(topic, callback string) error + Subscribers(topic string) ([]Subscriber, error) + Close() error +} + +type Subscriber struct { + Callback string + LeaseSeconds int64 + Secret string + Created time.Time +} + diff --git a/db/01_tables.sql b/db/01_tables.sql new file mode 100644 index 0000000..e192495 --- /dev/null +++ b/db/01_tables.sql @@ -0,0 +1,9 @@ +create table "subscribers" +( + id int GENERATED BY DEFAULT AS IDENTITY PRIMARY KEY, + topic varchar(255) not null, + callback varchar(255) not null, + lease_seconds int not null, + secret varchar(255) not null, + created timestamp not null +);