Compare commits: master...wip-jwt-ch
1 commit: d91b3d72b3 (author and date not captured)
.drone.yml (77 changed lines)

@@ -1,74 +1,32 @@
 ---
 kind: pipeline
 type: docker
-name: build and test
+name: build
 
 workspace:
   base: /go
   path: src/p83.nl/go/ekster
 
-trigger:
-  event:
-  - push
-
-services:
-- name: redis
-  image: redis:5
-- name: database
-  image: postgres:14
-  environment:
-    POSTGRES_DB: ekster_testing
-    POSTGRES_USER: postgres
-    POSTGRES_PASSWORD: simple
-    POSTGRES_HOST_AUTH_METHOD: trust
 
 steps:
 - name: testing
-  image: golang:1.18-alpine
+  image: golang:1.16-alpine
   environment:
     CGO_ENABLED: 0
     GOOS: linux
     GOARCH: amd64
   commands:
-  - go version
   - apk --no-cache add git
   - go get -d -t ./...
-  - go build -buildvcs=false p83.nl/go/ekster/cmd/eksterd
+  - go install honnef.co/go/tools/cmd/staticcheck@latest
+  - go build p83.nl/go/ekster/cmd/eksterd
  - go vet ./...
   - go test -v ./...
+  - staticcheck ./...
 
----
-kind: pipeline
-type: docker
-name: move to production
-
-workspace:
-  base: /go
-  path: src/p83.nl/go/ekster
-
-trigger:
-  event:
-  - promote
-  target:
-  - production
-
-steps:
-- name: build
-  image: golang:1.18-alpine
-  environment:
-    CGO_ENABLED: 0
-    GOOS: linux
-    GOARCH: amd64
-  commands:
-  - go version
-  - apk --no-cache add git
-  - go get -d -t ./...
-  - go build -buildvcs=false p83.nl/go/ekster/cmd/eksterd
 
 - name: publish-personal
   image: plugins/docker
   depends_on:
-  - build
+  - testing
   settings:
     repo: registry.stuifzandapp.com/microsub-server
     registry: registry.stuifzandapp.com

@@ -76,6 +34,24 @@ steps:
       from_secret: docker_username
     password:
       from_secret: docker_password
+  when:
+    event:
+    - promote
+    target:
+    - production
+
+# - name: publish-docker
+#   image: plugins/docker
+#   depends_on:
+#     - testing
+#   settings:
+#     repo: pstuifzand/ekster
+#     tags:
+#       - alpine
+#     username:
+#       from_secret: docker_official_username
+#     password:
+#       from_secret: docker_official_password
 
 - name: deploy
   image: appleboy/drone-ssh

@@ -90,3 +66,8 @@ steps:
     - cd /home/microsub/microsub
     - docker-compose pull web
     - docker-compose up -d
+  when:
+    event:
+    - promote
+    target:
+    - production
Changed file (name not captured; a Dockerfile):

@@ -3,4 +3,5 @@ RUN apk --no-cache add ca-certificates
 WORKDIR /opt/micropub
 EXPOSE 80
 COPY ./eksterd /app/
+COPY ./templates /app/templates
 ENTRYPOINT ["/app/eksterd"]
Changed file (name not captured):

@@ -242,7 +242,6 @@ func channelID(sub microsub.Microsub, channelNameOrID string) (string, error) {
         // we encountered an error, so we are not sure if it worked
         return channelNameOrID, err
     }
 
     for _, c := range channels {
         if c.Name == channelNameOrID {
             return c.UID, nil
Deleted file (name not captured, 123 lines; a database-backed test of getChannelFromAuthorization):

@@ -1,123 +0,0 @@
-/*
- * Ekster is a microsub server
- * Copyright (c) 2022 The Ekster authors
- *
- * This program is free software: you can redistribute it and/or modify
- * it under the terms of the GNU General Public License as published by
- * the Free Software Foundation, either version 3 of the License, or
- * (at your option) any later version.
- *
- * This program is distributed in the hope that it will be useful,
- * but WITHOUT ANY WARRANTY; without even the implied warranty of
- * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
- * GNU General Public License for more details.
- *
- * You should have received a copy of the GNU General Public License
- * along with this program. If not, see <http://www.gnu.org/licenses/>.
- */
-
-package main
-
-import (
-    "database/sql"
-    "log"
-    "net/http/httptest"
-    "os"
-    "testing"
-
-    "github.com/gomodule/redigo/redis"
-    "github.com/stretchr/testify/assert"
-    "github.com/stretchr/testify/suite"
-)
-
-type DatabaseSuite struct {
-    suite.Suite
-
-    URL      string
-    Database *sql.DB
-
-    RedisURL string
-    Redis    redis.Conn
-}
-
-func (s *DatabaseSuite) SetupSuite() {
-    db, err := sql.Open("postgres", s.URL)
-    if err != nil {
-        log.Fatal(err)
-    }
-    s.Database = db
-
-    conn, err := redis.Dial("tcp", s.RedisURL)
-    if err != nil {
-        log.Fatal(err)
-    }
-    s.Redis = conn
-    _, err = s.Redis.Do("SELECT", "1")
-    if err != nil {
-        log.Fatal(err)
-    }
-}
-
-func (s *DatabaseSuite) TearDownSuite() {
-    err := s.Database.Close()
-    if err != nil {
-        log.Fatal(err)
-    }
-
-    err = s.Redis.Close()
-    if err != nil {
-        log.Fatal(err)
-    }
-}
-
-type databaseSuite struct {
-    DatabaseSuite
-}
-
-func (d *databaseSuite) TestGetChannelFromAuthorization() {
-    _, err := d.Database.Exec(`truncate "sources", "channels", "feeds", "subscriptions","items"`)
-    assert.NoError(d.T(), err, "truncate sources, channels, feeds")
-    row := d.Database.QueryRow(`INSERT INTO "channels" (uid, name, created_at, updated_at) VALUES ('abcdef', 'Channel', now(), now()) RETURNING "id"`)
-    var id int
-    err = row.Scan(&id)
-    assert.NoError(d.T(), err, "insert channel")
-    _, err = d.Database.Exec(`INSERT INTO "sources" (channel_id, auth_code, created_at, updated_at) VALUES ($1, '1234', now(), now())`, id)
-    assert.NoError(d.T(), err, "insert sources")
-
-    // source_id found
-    r := httptest.NewRequest("POST", "/micropub?source_id=1234", nil)
-    _, c, err := getChannelFromAuthorization(r, d.Redis, d.Database)
-    assert.NoError(d.T(), err, "channel from source_id")
-    assert.Equal(d.T(), "abcdef", c, "channel uid found")
-
-    // source_id not found
-    r = httptest.NewRequest("POST", "/micropub?source_id=1111", nil)
-    _, c, err = getChannelFromAuthorization(r, d.Redis, d.Database)
-    assert.Error(d.T(), err, "channel from authorization header")
-    assert.Equal(d.T(), "", c, "channel uid found")
-}
-
-func TestDatabaseSuite(t *testing.T) {
-    if testing.Short() {
-        t.Skip("Skip test for database")
-    }
-
-    databaseURL := os.Getenv("DATABASE_TEST_URL")
-    if databaseURL == "" {
-        databaseURL = "host=database user=postgres password=simple dbname=ekster_testing sslmode=disable"
-    }
-    databaseSuite := &databaseSuite{
-        DatabaseSuite{
-            URL:      databaseURL,
-            RedisURL: "redis:6379",
-        },
-    }
-
-    databaseURL = "postgres://postgres@database/ekster_testing?sslmode=disable&user=postgres&password=simple"
-    err := runMigrations(databaseURL)
-    if err != nil {
-        log.Fatal(err)
-    }
-
-    suite.Run(t, databaseSuite)
-}
Deleted files (names not captured; four SQL migrations):

@@ -1,4 +0,0 @@
-alter table "feeds"
-    drop column "tier",
-    drop column "unmodified",
-    drop column "next_fetch_at";

@@ -1,4 +0,0 @@
-alter table "feeds"
-    add column "tier" int default 0,
-    add column "unmodified" int default 0,
-    add column "next_fetch_at" timestamptz;

@@ -1 +0,0 @@
-DROP TABLE "sources";

@@ -1,20 +0,0 @@
-BEGIN;
-CREATE OR REPLACE FUNCTION update_timestamp()
-RETURNS TRIGGER AS $$
-BEGIN
-    NEW.updated_at = now();
-    RETURN NEW;
-END;
-$$ language 'plpgsql';
-COMMIT;
-
-CREATE TABLE "sources" (
-    "id" int primary key generated always as identity,
-    "channel_id" int not null,
-    "auth_code" varchar(64) not null,
-    "created_at" timestamp DEFAULT current_timestamp,
-    "updated_at" timestamp DEFAULT current_timestamp
-);
-
-CREATE TRIGGER sources_update_timestamp BEFORE INSERT OR UPDATE ON "sources"
-    FOR EACH ROW EXECUTE PROCEDURE update_timestamp();
Changed file (name not captured; the eksterd HTTP/auth handler):

@@ -29,9 +29,11 @@ import (
     "log"
     "net/http"
     "net/url"
+    "os"
     "strings"
     "time"
 
+    "github.com/cristalhq/jwt/v4"
     "p83.nl/go/ekster/pkg/indieauth"
     "p83.nl/go/ekster/pkg/microsub"
     "p83.nl/go/ekster/pkg/util"

@@ -44,7 +46,7 @@ import (
 var templates embed.FS
 
 type mainHandler struct {
-    Backend     *memoryBackend
+    Backend     microsub.Microsub
     BaseURL     string
     TemplateDir string
     pool        *redis.Pool

@@ -111,6 +113,12 @@ type authRequest struct {
     AccessToken string `redis:"access_token"`
 }
 
+type micropubClaims struct {
+    jwt.RegisteredClaims
+    Channel string
+    Me      string
+}
+
 func newMainHandler(backend *memoryBackend, baseURL, templateDir string, pool *redis.Pool) (*mainHandler, error) {
     h := &mainHandler{Backend: backend}

@@ -164,11 +172,13 @@ func loadSession(sessionVar string, conn redis.Conn) (session, error) {
     sessionKey := "session:" + sessionVar
     data, err := redis.Values(conn.Do("HGETALL", sessionKey))
     if err != nil {
+        log.Println(err)
         return sess, err
     }
 
     err = redis.ScanStruct(data, &sess)
     if err != nil {
+        log.Println(err)
         return sess, err
     }

@@ -230,18 +240,19 @@ func verifyAuthCode(code, redirectURI, authEndpoint, clientID string) (bool, *au
     return false, nil, fmt.Errorf("unknown content-type %q while verifying authorization_code", contentType)
 }
 
-func isLoggedIn(backend *memoryBackend, sess *session) bool {
+func isLoggedIn(sess *session) bool {
     if !sess.LoggedIn {
         return false
     }
 
-    if !backend.AuthEnabled {
-        return true
-    }
-
-    if sess.Me != backend.Me {
-        return false
-    }
+    // FIXME: Do we need this?
+    // if !backend.AuthEnabled {
+    //     return true
+    // }
+    //
+    // if sess.Me != backend.Me {
+    //     return false
+    // }
 
     return true
 }

@@ -251,7 +262,6 @@ func performIndieauthCallback(clientID string, r *http.Request, sess *session) (
     if state != sess.State {
         return false, &authResponse{}, fmt.Errorf("mismatched state")
     }
 
     code := r.Form.Get("code")
     return verifyAuthCode(code, sess.RedirectURI, sess.AuthorizationEndpoint, clientID)
 }

@@ -377,7 +387,7 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
         return
     }
 
-    if !isLoggedIn(h.Backend, &sess) {
+    if !isLoggedIn(&sess) {
         w.WriteHeader(401)
         fmt.Fprintf(w, "Unauthorized")
         return

@@ -402,14 +412,14 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
     for _, v := range page.Channels {
         if v.UID == currentChannel {
             page.CurrentChannel = v
-            if setting, e := h.Backend.Settings[v.UID]; e {
-                page.CurrentSetting = setting
-            } else {
+            // if setting, e := h.Backend.Settings[v.UID]; e {
+            //     page.CurrentSetting = setting
+            // } else {
             page.CurrentSetting = channelSetting{}
-            }
-            if page.CurrentSetting.ChannelType == "" {
+            // }
+            // if page.CurrentSetting.ChannelType == "" {
             page.CurrentSetting.ChannelType = "postgres-stream"
-            }
+            // }
             page.ExcludedTypeNames = map[string]string{
                 "repost": "Reposts",
                 "like":   "Likes",

@@ -448,7 +458,7 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
         return
     }
 
-    if !isLoggedIn(h.Backend, &sess) {
+    if !isLoggedIn(&sess) {
         w.WriteHeader(401)
         fmt.Fprintf(w, "Unauthorized")
         return

@@ -476,7 +486,7 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
         return
     }
 
-    if !isLoggedIn(h.Backend, &sess) {
+    if !isLoggedIn(&sess) {
         w.WriteHeader(401)
         fmt.Fprintf(w, "Unauthorized")
         return

@@ -498,9 +508,39 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
         }
         return
     } else if r.URL.Path == "/auth" {
-        // check if we are logged in
-        // TODO: if not logged in, make sure we get back here
+        // check arguments for auth
+        query := r.URL.Query()
+        responseType := query.Get("response_type")
+        if responseType != "code" {
+            http.Error(w, "Unsupported response_type", 400)
+            return
+        }
+
+        redirectURI := query.Get("redirect_uri")
+        if _, err := url.Parse(redirectURI); err != nil {
+            http.Error(w, "Missing redirect_uri", 400)
+            return
+        }
+
+        clientID := query.Get("client_id")
+        if _, err := url.Parse(clientID); err != nil {
+            http.Error(w, "Missing client_id", 400)
+            return
+        }
+
+        me := query.Get("me")
+        if _, err := url.Parse(me); err != nil {
+            http.Error(w, "Missing me", 400)
+            return
+        }
+
+        state := query.Get("state")
+        scope := query.Get("scope")
+        if scope == "" {
+            scope = "create"
+        }
+
+        // Check if the client is logging in
         sessionVar := getSessionCookie(w, r)
 
         sess, err := loadSession(sessionVar, conn)

@@ -510,7 +550,7 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
         return
     }
 
-    if !isLoggedIn(h.Backend, &sess) {
+    if !isLoggedIn(&sess) {
         sess.NextURI = r.URL.String()
         saveSession(sessionVar, &sess, conn)
         http.Redirect(w, r, "/", http.StatusFound)

@@ -520,18 +560,6 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
         sess.NextURI = r.URL.String()
         saveSession(sessionVar, &sess, conn)
 
-        query := r.URL.Query()
-
-        // responseType := query.Get("response_type") // TODO: check response_type
-        me := query.Get("me")
-        clientID := query.Get("client_id")
-        redirectURI := query.Get("redirect_uri")
-        state := query.Get("state")
-        scope := query.Get("scope")
-        if scope == "" {
-            scope = "create"
-        }
-
         authReq := authRequest{
             Me:       me,
             ClientID: clientID,

@@ -585,6 +613,18 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 
         // redirect to endpoint
         me := r.Form.Get("url")
+        if !strings.HasPrefix("https://", me) {
+            me = fmt.Sprintf("https://%s", me)
+        }
+        meURL, err := url.Parse(me)
+        if err != nil {
+            http.Error(w, fmt.Sprintf("Not a url: %s", me), http.StatusBadRequest)
+            return
+        }
+        if meURL.Path == "" {
+            meURL.Path = "/"
+        }
+        me = meURL.String()
 
         endpoints, err := getEndpoints(me)
         if err != nil {

@@ -695,17 +735,31 @@ func (h *mainHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
         fmt.Fprintf(w, "ERROR: %q", err)
         return
     }
-    token := util.RandStringBytes(32)
 
-    _, err = conn.Do("HMSET", redis.Args{}.Add("token:"+token).AddFlat(&auth)...)
+    key := []byte(os.Getenv("APP_SECRET"))
+    signer, err := jwt.NewSignerHS(jwt.HS256, key)
     if err != nil {
-        log.Println(err)
-        fmt.Fprintf(w, "ERROR: %q", err)
+        log.Printf("could not create signer for jwt: %s", err)
+        http.Error(w, err.Error(), http.StatusInternalServerError)
+        return
+    }
+    builder := jwt.NewBuilder(signer)
+    token, err := builder.Build(&micropubClaims{
+        RegisteredClaims: jwt.RegisteredClaims{
+            IssuedAt: jwt.NewNumericDate(time.Now()),
+        },
+        Channel: auth.Channel,
+        Me:      auth.Me,
+    })
+    if err != nil {
+        log.Printf("could not create signer for jwt: %s", err)
+        http.Error(w, err.Error(), http.StatusInternalServerError)
         return
     }
 
     res := authTokenResponse{
         Me:          auth.Me,
-        AccessToken: token,
+        AccessToken: token.String(),
         TokenType:   "Bearer",
         Scope:       auth.Scope,
     }
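The hunks above replace the random Redis-backed access token with a signed JWT: a micropubClaims payload is built and signed with HS256 using the APP_SECRET environment variable, and the micropub handler later verifies it (see the verifyJWT hunk at the end of this compare view). The following is a minimal, self-contained sketch of that sign/verify round trip with github.com/cristalhq/jwt/v4, using the same calls that appear in the diff; the channel and me values are placeholder examples, not values from the branch.

```go
package main

import (
	"encoding/json"
	"fmt"
	"log"
	"os"
	"time"

	"github.com/cristalhq/jwt/v4"
)

// micropubClaims mirrors the struct added in the diff: registered JWT
// claims plus the channel and "me" URL the token grants access to.
type micropubClaims struct {
	jwt.RegisteredClaims
	Channel string
	Me      string
}

func main() {
	// The diff reads the HS256 key from the APP_SECRET environment variable.
	key := []byte(os.Getenv("APP_SECRET"))

	// Sign: build an HS256 token carrying the micropub claims.
	signer, err := jwt.NewSignerHS(jwt.HS256, key)
	if err != nil {
		log.Fatal(err)
	}
	token, err := jwt.NewBuilder(signer).Build(&micropubClaims{
		RegisteredClaims: jwt.RegisteredClaims{IssuedAt: jwt.NewNumericDate(time.Now())},
		Channel:          "abcdef",              // placeholder channel uid
		Me:               "https://example.com/", // placeholder me URL
	})
	if err != nil {
		log.Fatal(err)
	}

	// Verify: parse the token string back with the same key and decode the
	// claims, as verifyJWT does in the micropub handler.
	verifier, err := jwt.NewVerifierHS(jwt.HS256, key)
	if err != nil {
		log.Fatal(err)
	}
	parsed, err := jwt.Parse([]byte(token.String()), verifier)
	if err != nil {
		log.Fatal(err)
	}
	var claims micropubClaims
	if err := json.Unmarshal(parsed.Claims(), &claims); err != nil {
		log.Fatal(err)
	}
	fmt.Println(claims.Channel, claims.Me)
}
```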
New file: cmd/eksterd/http_test.go (95 lines)

@@ -0,0 +1,95 @@
+/*
+ * Ekster is a microsub server
+ * Copyright (c) 2021 The Ekster authors
+ *
+ * This program is free software: you can redistribute it and/or modify
+ * it under the terms of the GNU General Public License as published by
+ * the Free Software Foundation, either version 3 of the License, or
+ * (at your option) any later version.
+ *
+ * This program is distributed in the hope that it will be useful,
+ * but WITHOUT ANY WARRANTY; without even the implied warranty of
+ * MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE. See the
+ * GNU General Public License for more details.
+ *
+ * You should have received a copy of the GNU General Public License
+ * along with this program. If not, see <http://www.gnu.org/licenses/>.
+ */
+
+package main
+
+import (
+    "io/ioutil"
+    "math/rand"
+    "net/http"
+    "net/http/httptest"
+    "net/url"
+    "testing"
+
+    "github.com/gomodule/redigo/redis"
+    "github.com/rafaeljusto/redigomock/v3"
+    "github.com/stretchr/testify/assert"
+    "p83.nl/go/ekster/pkg/server"
+    "p83.nl/go/ekster/pkg/util"
+)
+
+func init() {
+    rand.Seed(1)
+}
+
+func TestMainHandler_ServeHTTP_NoArgs(t *testing.T) {
+    conn := redigomock.NewConn()
+    pool := &redis.Pool{
+        // Return the same connection mock for each Get() call.
+        Dial:    func() (redis.Conn, error) { return conn, nil },
+        MaxIdle: 10,
+    }
+
+    h := mainHandler{Backend: &memoryBackend{AuthEnabled: true}, BaseURL: "", TemplateDir: "", pool: pool}
+    r := httptest.NewRequest(http.MethodGet, "/auth", nil)
+    w := httptest.NewRecorder()
+    h.ServeHTTP(w, r)
+    assert.Equal(t, 400, w.Code)
+}
+
+func TestMainHandler_ServeHTTP_Args(t *testing.T) {
+    conn := redigomock.NewConn()
+    pool := &redis.Pool{
+        // Return the same connection mock for each Get() call.
+        Dial:    func() (redis.Conn, error) { return conn, nil },
+        MaxIdle: 10,
+    }
+
+    q := url.Values{}
+    q.Add("response_type", "code")
+    q.Add("client_id", "https://example.com/")
+    q.Add("redirect_uri", "https://example.com/callback")
+    q.Add("me", "https://p83.nl/")
+    state := util.RandStringBytes(32)
+    q.Add("state", state)
+    q.Add("scope", "create")
+    h := mainHandler{Backend: &server.NullBackend{}, BaseURL: "", TemplateDir: "", pool: pool}
+    r := httptest.NewRequest(http.MethodGet, "/auth?"+q.Encode(), nil)
+    w := httptest.NewRecorder()
+
+    conn.Command("HGETALL", "session:FpLSjFbcXoEFfRsW").ExpectMap(map[string]string{
+        "logged_in": "1",
+    })
+    conn.Command(
+        "HMSET", "state:XVlBzgbaiCMRAjWwhTHctcuAxhxKQFDa",
+        "me", "https://p83.nl/",
+        "client_id", "https://example.com/",
+        "scope", "create",
+        "redirect_uri", "https://example.com/callback",
+        "state", "XVlBzgbaiCMRAjWwhTHctcuAxhxKQFDa",
+        "code", "",
+        "channel", "",
+        "access_token", "",
+    )
+
+    h.ServeHTTP(w, r)
+    assert.Equal(t, 302, w.Code)
+    body, err := ioutil.ReadAll(w.Result().Body)
+    if assert.NoError(t, err) {
+        assert.Equal(t, "", string(body))
+    }
+}
Changed file (name not captured; the WebSub hub backend):

@@ -135,14 +135,15 @@ func (h *hubIncomingBackend) UpdateFeed(processor ContentProcessor, subscription
     log.Println("UpdateFeed", subscriptionID)
 
     db := h.database
+    var (
+        topic   string
+        channel string
+        feedID  string
+    )
 
     // Process all channels that contains this feed
-    rows, err := db.Query(`
-select topic, c.uid, f.id, c.name
-from subscriptions s
-inner join feeds f on f.url = s.topic
-inner join channels c on c.id = f.channel_id
-where s.id = $1
-`,
+    rows, err := db.Query(
+        `select topic, c.uid, f.id from subscriptions s inner join feeds f on f.url = s.topic inner join channels c on c.id = f.channel_id where s.id = $1`,
         subscriptionID,
     )
     if err != nil {

@@ -150,18 +151,16 @@ where s.id = $1
     }
 
     for rows.Next() {
-        var topic, channel, feedID, channelName string
-
-        err = rows.Scan(&topic, &channel, &feedID, &channelName)
+        err = rows.Scan(&topic, &channel, &feedID)
         if err != nil {
             log.Println(err)
             continue
         }
 
-        log.Printf("Updating feed %s %q in %q (%s)\n", feedID, topic, channelName, channel)
-        _, err = processor.ProcessContent(channel, feedID, topic, contentType, body)
+        log.Printf("Updating feed %s %q in %q\n", feedID, topic, channel)
+        err = processor.ProcessContent(channel, feedID, topic, contentType, body)
         if err != nil {
-            log.Printf("could not process content for channel %s: %s", channelName, err)
+            log.Printf("could not process content for channel %s: %s", channel, err)
         }
     }

@@ -224,24 +223,22 @@ func (h *hubIncomingBackend) Subscribe(feed *Feed) error {
 }
 
 func (h *hubIncomingBackend) run() error {
-    ticker := time.NewTicker(1 * time.Minute)
+    ticker := time.NewTicker(10 * time.Minute)
     quit := make(chan struct{})
 
     go func() {
         for {
             select {
             case <-ticker.C:
-                log.Println("Getting feeds for WebSub started")
+                log.Println("Getting feeds for WebSub")
                 varWebsub.Add("runs", 1)
 
                 feeds, err := h.Feeds()
                 if err != nil {
                     log.Println("Feeds failed:", err)
-                    log.Println("Getting feeds for WebSub completed")
                     continue
                 }
 
-                log.Printf("Found %d feeds", len(feeds))
-
                 for _, feed := range feeds {
                     log.Printf("Looking at %s\n", feed.URL)
                     if feed.ResubscribeAt != nil && time.Now().After(*feed.ResubscribeAt) {

@@ -257,8 +254,6 @@ func (h *hubIncomingBackend) run() error {
                     }
                 }
 
-                log.Println("Getting feeds for WebSub completed")
-
             case <-quit:
                 ticker.Stop()
                 return
Changed file (name not captured; eksterd main):

@@ -167,8 +167,8 @@ func main() {
     // }
 
     // TODO(peter): automatically gather this information from login or otherwise
-    databaseURL := "postgres://postgres@database/ekster?sslmode=disable&user=postgres&password=simple"
-    err := runMigrations(databaseURL)
+    err := runMigrations()
     if err != nil {
         log.Fatalf("Error with migrations: %s", err)
     }

@@ -205,12 +205,12 @@ func (l Log) Verbose() bool {
     return false
 }
 
-func runMigrations(databaseURL string) error {
+func runMigrations() error {
     d, err := iofs.New(migrations, "db/migrations")
     if err != nil {
         return err
     }
-    m, err := migrate.NewWithSourceInstance("iofs", d, databaseURL)
+    m, err := migrate.NewWithSourceInstance("iofs", d, "postgres://postgres@database/ekster_testing?sslmode=disable&user=postgres&password=simple")
     if err != nil {
         return err
     }
Changed file (name not captured; the memoryBackend microsub backend):

@@ -27,17 +27,14 @@ import (
     "io"
     "io/ioutil"
     "log"
-    "math"
     "net/http"
     "net/url"
     "regexp"
-    "strconv"
     "strings"
     "sync"
     "time"
 
     "github.com/lib/pq"
-    "github.com/pkg/errors"
     "p83.nl/go/ekster/pkg/auth"
     "p83.nl/go/ekster/pkg/fetch"
     "p83.nl/go/ekster/pkg/microsub"

@@ -105,15 +102,6 @@ type newItemMessage struct {
     Channel string `json:"channel"`
 }
 
-type feed struct {
-    UID         string // channel
-    ID          int
-    URL         string
-    Tier        int
-    Unmodified  int
-    NextFetchAt time.Time
-}
-
 func (b *memoryBackend) AuthTokenAccepted(header string, r *auth.TokenResponse) (bool, error) {
     conn := b.pool.Get()
     defer func() {

@@ -156,9 +144,6 @@ GROUP BY c.id;
         }})
     }
 
-    util.StablePartition(channels, 0, len(channels), func(i int) bool {
-        return channels[i].Unread.HasUnread()
-    })
     return channels, nil
 }
 

@@ -190,7 +175,7 @@ func (b *memoryBackend) ChannelsCreate(name string) (microsub.Channel, error) {
     for {
         varMicrosub.Add("ChannelsCreate.RandStringBytes", 1)
         channel.UID = util.RandStringBytes(24)
-        result, err := b.database.Exec(`insert into "channels" ("uid", "name", "created_at") values ($1, $2, DEFAULT)`, channel.UID, channel.Name)
+        result, err := b.database.Exec(`insert into "channels" ("uid", "name", "created_at") values($1, $2, DEFAULT)`, channel.UID, channel.Name)
         if err != nil {
             log.Println("channels insert", err)
             if !shouldRetryWithNewUID(err, try) {

@@ -199,7 +184,7 @@ func (b *memoryBackend) ChannelsCreate(name string) (microsub.Channel, error) {
             try++
             continue
         }
-        if n, err := result.RowsAffected(); err == nil {
+        if n, err := result.RowsAffected(); err != nil {
             if n > 0 {
                 b.broker.Notifier <- sse.Message{Event: "new channel", Object: channelMessage{1, channel}}
             }

@@ -234,22 +219,15 @@ func (b *memoryBackend) ChannelsDelete(uid string) error {
     b.broker.Notifier <- sse.Message{Event: "delete channel", Object: channelDeletedMessage{1, uid}}
     return nil
 }
-func (b *memoryBackend) updateFeed(feed feed) error {
-    _, err := b.database.Exec(`
-UPDATE "feeds"
-SET "tier" = $2, "unmodified" = $3, "next_fetch_at" = $4
-WHERE "id" = $1
-`, feed.ID, feed.Tier, feed.Unmodified, feed.NextFetchAt)
-    return err
+
+type feed struct {
+    UID string // channel
+    ID  int
+    URL string
 }
 
 func (b *memoryBackend) getFeeds() ([]feed, error) {
-    rows, err := b.database.Query(`
-SELECT "f"."id", "f"."url", "c"."uid", "f"."tier","f"."unmodified","f"."next_fetch_at"
-FROM "feeds" AS "f"
-INNER JOIN public.channels c ON c.id = f.channel_id
-WHERE next_fetch_at IS NULL OR next_fetch_at < now()
-`)
+    rows, err := b.database.Query(`SELECT "f"."id", "f"."url", "c"."uid" FROM "feeds" AS "f" INNER JOIN public.channels c on c.id = f.channel_id`)
     if err != nil {
         return nil, err
     }

@@ -258,49 +236,29 @@ WHERE next_fetch_at IS NULL OR next_fetch_at < now()
     for rows.Next() {
         var feedID int
         var feedURL, UID string
-        var tier, unmodified int
-        var nextFetchAt sql.NullTime
 
-        err = rows.Scan(&feedID, &feedURL, &UID, &tier, &unmodified, &nextFetchAt)
+        err = rows.Scan(&feedID, &feedURL, &UID)
         if err != nil {
             log.Printf("while scanning feeds: %s", err)
             continue
         }
 
-        var fetchTime time.Time
-        if nextFetchAt.Valid {
-            fetchTime = nextFetchAt.Time
-        } else {
-            fetchTime = time.Now()
-        }
-
-        feeds = append(
-            feeds,
-            feed{
-                UID:         UID,
-                ID:          feedID,
-                URL:         feedURL,
-                Tier:        tier,
-                Unmodified:  unmodified,
-                NextFetchAt: fetchTime,
-            },
-        )
+        feeds = append(feeds, feed{UID, feedID, feedURL})
     }
 
     return feeds, nil
 }
 
 func (b *memoryBackend) run() {
-    b.ticker = time.NewTicker(1 * time.Minute)
+    b.ticker = time.NewTicker(10 * time.Minute)
     b.quit = make(chan struct{})
 
     go func() {
-        b.RefreshFeeds()
-
         for {
             select {
             case <-b.ticker.C:
                 b.RefreshFeeds()
 
             case <-b.quit:
                 b.ticker.Stop()
                 return

@@ -310,89 +268,43 @@ func (b *memoryBackend) run() {
 }
 
 func (b *memoryBackend) RefreshFeeds() {
-    log.Println("Feed update process started")
-    defer log.Println("Feed update process completed")
-
     feeds, err := b.getFeeds()
     if err != nil {
         return
     }
 
-    log.Printf("Found %d feeds", len(feeds))
-
     count := 0
 
     for _, feed := range feeds {
-        log.Println("Processing", feed.URL)
-        err := b.refreshFeed(feed)
+        feedURL := feed.URL
+        feedID := feed.ID
+        uid := feed.UID
+        log.Println("Processing", feedURL)
+        resp, err := b.Fetch3(uid, feedURL)
         if err != nil {
-            b.addNotification("Error while fetching feed", feed, err)
+            log.Printf("Error while Fetch3 of %s: %v\n", feedURL, err)
+            b.addNotification("Error while fetching feed", feedURL, err)
+            count++
             continue
         }
+        err = b.ProcessContent(uid, fmt.Sprintf("%d", feedID), feedURL, resp.Header.Get("Content-Type"), resp.Body)
+        if err != nil {
+            log.Printf("Error while processing content for %s: %v\n", feedURL, err)
+            b.addNotification("Error while processing feed", feedURL, err)
             count++
+            continue
+        }
+        _ = resp.Body.Close()
     }
 
     if count > 0 {
         _ = b.updateChannelUnreadCount("notifications")
     }
-    log.Printf("Processed %d feeds", count)
 }
 
-func (b *memoryBackend) refreshFeed(feed feed) error {
-    resp, err := b.Fetch3(feed.UID, feed.URL)
-    if err != nil {
-        return fmt.Errorf("while Fetch3 of %s: %w", feed.URL, err)
-    }
-    defer resp.Body.Close()
-
-    changed, err := b.ProcessContent(feed.UID, fmt.Sprintf("%d", feed.ID), feed.URL, resp.Header.Get("Content-Type"), resp.Body)
-    if err != nil {
-        return fmt.Errorf("in ProcessContent of %s: %w", feed.URL, err)
-    }
-
-    if changed {
-        feed.Tier--
-    } else {
-        feed.Unmodified++
-    }
-
-    if feed.Unmodified >= 2 {
-        feed.Tier++
-        feed.Unmodified = 0
-    }
-
-    if feed.Tier > 10 {
-        feed.Tier = 10
-    }
-
-    if feed.Tier < 0 {
-        feed.Tier = 0
-    }
-
-    minutes := time.Duration(math.Ceil(math.Exp2(float64(feed.Tier))))
-
-    feed.NextFetchAt = time.Now().Add(minutes * time.Minute)
-
-    log.Printf("Next Fetch in %d minutes at %v", minutes, feed.NextFetchAt.Format(time.RFC3339))
-
-    err = b.updateFeed(feed)
-    if err != nil {
-        log.Printf("Error: while updating feed %v: %v", feed, err)
-        // don't return error, because it becomes a notification
-        return nil
-    }
-
-    return nil
-}
-
-func (b *memoryBackend) addNotification(name string, feed feed, err error) {
-    _, err = b.channelAddItem("notifications", microsub.Item{
+func (b *memoryBackend) addNotification(name string, feedURL string, err error) {
+    err = b.channelAddItem("notifications", microsub.Item{
         Type: "entry",
-        Source: &microsub.Source{
-            ID:   strconv.Itoa(feed.ID),
-            URL:  feed.URL,
-            Name: feed.URL,
-        },
         Name: name,
         Content: &microsub.Content{
             Text: fmt.Sprintf("ERROR: while updating feed: %s", err),

@@ -413,12 +325,9 @@ func (b *memoryBackend) TimelineGet(before, after, channel string) (microsub.Tim
         return microsub.Timeline{Items: []microsub.Item{}}, err
     }
 
-    timelineBackend, err := b.getTimeline(channel)
-    if err != nil {
-        return microsub.Timeline{}, err
-    }
+    timelineBackend := b.getTimeline(channel)
 
-    // _ = b.updateChannelUnreadCount(channel)
+    _ = b.updateChannelUnreadCount(channel)
 
     return timelineBackend.Items(before, after)
 }

@@ -445,7 +354,7 @@ func (b *memoryBackend) FollowGetList(uid string) ([]microsub.Feed, error) {
 }
 
 func (b *memoryBackend) FollowURL(uid string, url string) (microsub.Feed, error) {
-    subFeed := microsub.Feed{Type: "feed", URL: url}
+    feed := microsub.Feed{Type: "feed", URL: url}
 
     var channelID int
     err := b.database.QueryRow(`SELECT "id" FROM "channels" WHERE "uid" = $1`, uid).Scan(&channelID)

@@ -458,36 +367,28 @@ func (b *memoryBackend) FollowURL(uid string, url string) (microsub.Feed, error)
 
     var feedID int
     err = b.database.QueryRow(
-        `INSERT INTO "feeds" ("channel_id", "url", "tier", "unmodified", "next_fetch_at") VALUES ($1, $2, 1, 0, now()) RETURNING "id"`,
+        `INSERT INTO "feeds" ("channel_id", "url") VALUES ($1, $2) RETURNING "id"`,
         channelID,
-        subFeed.URL,
+        feed.URL,
     ).Scan(&feedID)
     if err != nil {
-        return subFeed, err
+        return feed, err
     }
 
-    var newFeed = feed{
-        ID:          feedID,
-        UID:         uid,
-        URL:         url,
-        Tier:        1,
-        Unmodified:  0,
-        NextFetchAt: time.Now(),
-    }
-    resp, err := b.Fetch3(uid, subFeed.URL)
+    resp, err := b.Fetch3(uid, feed.URL)
     if err != nil {
         log.Println(err)
-        b.addNotification("Error while fetching feed", newFeed, err)
+        b.addNotification("Error while fetching feed", feed.URL, err)
         _ = b.updateChannelUnreadCount("notifications")
-        return subFeed, err
+        return feed, err
     }
     defer resp.Body.Close()
 
-    _, _ = b.ProcessContent(uid, fmt.Sprintf("%d", feedID), subFeed.URL, resp.Header.Get("Content-Type"), resp.Body)
+    _ = b.ProcessContent(uid, fmt.Sprintf("%d", feedID), feed.URL, resp.Header.Get("Content-Type"), resp.Body)
 
     _, _ = b.hubBackend.CreateFeed(url)
 
-    return subFeed, nil
+    return feed, nil
 }
 
 func (b *memoryBackend) UnfollowURL(uid string, url string) error {

@@ -624,16 +525,15 @@ func (b *memoryBackend) PreviewURL(previewURL string) (microsub.Timeline, error)
 }
 
 func (b *memoryBackend) MarkRead(channel string, uids []string) error {
-    tl, err := b.getTimeline(channel)
+    tl := b.getTimeline(channel)
+    err := tl.MarkRead(uids)
 
     if err != nil {
         return err
     }
 
-    if err = tl.MarkRead(uids); err != nil {
-        return err
-    }
-
-    if err = b.updateChannelUnreadCount(channel); err != nil {
+    err = b.updateChannelUnreadCount(channel)
+    if err != nil {
         return err
     }

@@ -684,35 +584,31 @@ func ProcessSourcedItems(fetcher fetch.Fetcher, fetchURL, contentType string, bo
 
 // ContentProcessor processes content for a channel and feed
 type ContentProcessor interface {
-    ProcessContent(channel, feedID, fetchURL, contentType string, body io.Reader) (bool, error)
+    ProcessContent(channel, feedID, fetchURL, contentType string, body io.Reader) error
 }
 
-// ProcessContent processes content of a feed, returns if the feed has changed or not
-func (b *memoryBackend) ProcessContent(channel, feedID, fetchURL, contentType string, body io.Reader) (bool, error) {
+func (b *memoryBackend) ProcessContent(channel, feedID, fetchURL, contentType string, body io.Reader) error {
     cachingFetch := WithCaching(b.pool, fetch.FetcherFunc(Fetch2))
 
     items, err := ProcessSourcedItems(cachingFetch, fetchURL, contentType, body)
     if err != nil {
-        return false, err
+        return err
     }
 
-    changed := false
-
     for _, item := range items {
         item.Source.ID = feedID
-        added, err := b.channelAddItemWithMatcher(channel, item)
+        err = b.channelAddItemWithMatcher(channel, item)
         if err != nil {
             log.Printf("ERROR: (feedID=%s) %s\n", feedID, err)
         }
-        changed = changed || added
     }
 
     err = b.updateChannelUnreadCount(channel)
     if err != nil {
-        return changed, err
+        return err
     }
 
-    return changed, nil
+    return nil
 }
 
 // Fetch3 fills stuff

@@ -721,12 +617,17 @@ func (b *memoryBackend) Fetch3(channel, fetchURL string) (*http.Response, error)
     return Fetch2(fetchURL)
 }
 
-func (b *memoryBackend) channelAddItemWithMatcher(channel string, item microsub.Item) (bool, error) {
+func (b *memoryBackend) channelAddItemWithMatcher(channel string, item microsub.Item) error {
     // an item is posted
     // check for all channels as channel
     // if regex matches item
     // - add item to channel
 
+    err := addToSearch(item, channel)
+    if err != nil {
+        return fmt.Errorf("addToSearch in channelAddItemWithMatcher: %v", err)
+    }
+
     var updatedChannels []string
 
     b.lock.RLock()

@@ -739,23 +640,23 @@ func (b *memoryBackend) channelAddItemWithMatcher(channel string, item microsub.
             switch v {
             case "repost":
                 if len(item.RepostOf) > 0 {
-                    return false, nil
+                    return nil
                 }
             case "like":
                 if len(item.LikeOf) > 0 {
-                    return false, nil
+                    return nil
                 }
             case "bookmark":
                 if len(item.BookmarkOf) > 0 {
-                    return false, nil
+                    return nil
                 }
             case "reply":
                 if len(item.InReplyTo) > 0 {
-                    return false, nil
+                    return nil
                 }
             case "checkin":
                 if item.Checkin != nil {
-                    return false, nil
+                    return nil
                 }
             }
         }

@@ -764,27 +665,19 @@ func (b *memoryBackend) channelAddItemWithMatcher(channel string, item microsub.
         re, err := regexp.Compile(setting.IncludeRegex)
         if err != nil {
             log.Printf("error in regexp: %q, %s\n", setting.IncludeRegex, err)
-            return false, nil
+            return nil
         }
 
         if matchItem(item, re) {
             log.Printf("Included %#v\n", item)
-            added, err := b.channelAddItem(channelKey, item)
+            err := b.channelAddItem(channelKey, item)
             if err != nil {
                 continue
             }
 
-            err = addToSearch(item, channel)
-            if err != nil {
-                return added, fmt.Errorf("addToSearch in channelAddItemWithMatcher: %v", err)
-            }
-
-            if added {
                 updatedChannels = append(updatedChannels, channelKey)
             }
         }
     }
-    }
 
     // Update all channels that have added items, because of the include matching
     for _, value := range updatedChannels {

@@ -804,26 +697,15 @@ func (b *memoryBackend) channelAddItemWithMatcher(channel string, item microsub.
         excludeRegex, err := regexp.Compile(setting.ExcludeRegex)
         if err != nil {
             log.Printf("error in regexp: %q\n", excludeRegex)
-            return false, nil
+            return nil
         }
         if matchItem(item, excludeRegex) {
             log.Printf("Excluded %#v\n", item)
-            return false, nil
+            return nil
         }
     }
 
-    added, err := b.channelAddItem(channel, item)
-
-    if err != nil {
-        return added, err
-    }
-
-    err = addToSearch(item, channel)
-    if err != nil {
-        return added, fmt.Errorf("addToSearch in channelAddItemWithMatcher: %v", err)
-    }
-
-    return added, nil
+    return b.channelAddItem(channel, item)
 }
 
 func matchItem(item microsub.Item, re *regexp.Regexp) bool {

@@ -852,15 +734,11 @@ func matchItemText(item microsub.Item, re *regexp.Regexp) bool {
     return re.MatchString(item.Name)
 }
 
-func (b *memoryBackend) channelAddItem(channel string, item microsub.Item) (bool, error) {
-    timelineBackend, err := b.getTimeline(channel)
-    if err != nil {
-        return false, err
-    }
-
+func (b *memoryBackend) channelAddItem(channel string, item microsub.Item) error {
+    timelineBackend := b.getTimeline(channel)
     added, err := timelineBackend.AddItem(item)
     if err != nil {
-        return added, err
+        return err
     }
 
     // Sent message to Server-Sent-Events

@@ -868,33 +746,23 @@ func (b *memoryBackend) channelAddItem(channel string, item microsub.Item) (bool
         b.broker.Notifier <- sse.Message{Event: "new item", Object: newItemMessage{item, channel}}
     }
 
-    return added, err
+    return err
 }
 
-// ErrNotUpdated is used when the unread count is not updated
-var ErrNotUpdated = errors.New("timeline unread count not updated")
-
-// ErrNotFound is used when the timeline is not found
-var ErrNotFound = errors.New("timeline not found")
-
 func (b *memoryBackend) updateChannelUnreadCount(channel string) error {
-    tl, err := b.getTimeline(channel)
-    if err != nil {
-        return err
-    }
-
-    unread, err := tl.Count()
-    if err != nil {
-        return ErrNotUpdated
-    }
-
-    var c = microsub.Channel{
-        UID:    channel,
-        Unread: microsub.Unread{Type: microsub.UnreadCount, UnreadCount: unread},
-    }
-
-    // Sent message to Server-Sent-Events
-    b.broker.Notifier <- sse.Message{Event: "new item in channel", Object: c}
-
+    // tl := b.getTimeline(channel)
+    // unread, err := tl.Count()
+    // if err != nil {
+    //     return err
+    // }
+    //
+    // currentCount := c.Unread.UnreadCount
+    // c.Unread = microsub.Unread{Type: microsub.UnreadCount, UnreadCount: unread}
+    //
+    // // Sent message to Server-Sent-Events
+    // if currentCount != unread {
+    //     b.broker.Notifier <- sse.Message{Event: "new item in channel", Object: c}
+    // }
     return nil
 }

@@ -974,12 +842,15 @@ func Fetch2(fetchURL string) (*http.Response, error) {
     return resp, err
 }
 
-func (b *memoryBackend) getTimeline(channel string) (timeline.Backend, error) {
+func (b *memoryBackend) getTimeline(channel string) timeline.Backend {
     // Set a default timeline type if not set
     timelineType := "postgres-stream"
+    // if setting, ok := b.Settings[channel]; ok && setting.ChannelType != "" {
+    //     timelineType = setting.ChannelType
+    // }
     tl := timeline.Create(channel, timelineType, b.pool, b.database)
     if tl == nil {
-        return tl, fmt.Errorf("timeline id %q: %w", channel, ErrNotFound)
+        log.Printf("no timeline found with name %q and type %q", channel, timelineType)
     }
-    return tl, nil
+    return tl
 }
@@ -20,11 +20,11 @@ package main
 
 import (
 	"crypto/sha1"
-	"database/sql"
 	"encoding/json"
 	"fmt"
 	"log"
 	"net/http"
+	"os"
 	"strings"
 	"time"
 
@@ -34,6 +34,8 @@ import (
 	"github.com/gomodule/redigo/redis"
 	"github.com/pkg/errors"
 	"willnorris.com/go/microformats"
+
+	"github.com/cristalhq/jwt/v4"
 )
 
 type micropubHandler struct {
@@ -66,7 +68,7 @@ func (h *micropubHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 	if r.Method == http.MethodPost {
 		var channel string
 
-		sourceID, channel, err := getChannelFromAuthorization(r, conn, h.Backend.database)
+		channel, err = getChannelFromAuthorization(r, conn)
 		if err != nil {
 			log.Println(err)
 			http.Error(w, "unauthorized", http.StatusUnauthorized)
@@ -100,12 +102,7 @@ func (h *micropubHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 		}
 		item.ID = newID
 
-		item.Source = &microsub.Source{
-			ID:   fmt.Sprintf("micropub:%d", sourceID),
-			Name: fmt.Sprintf("Source %d", sourceID),
-		}
-
-		_, err = h.Backend.channelAddItemWithMatcher(channel, *item)
+		err = h.Backend.channelAddItemWithMatcher(channel, *item)
 		if err != nil {
 			log.Printf("could not add item to channel %s: %v", channel, err)
 		}
@@ -171,23 +168,16 @@ func parseIncomingItem(r *http.Request) (*microsub.Item, error) {
 	return nil, fmt.Errorf("content-type %q is not supported", contentType)
 }
 
-func getChannelFromAuthorization(r *http.Request, conn redis.Conn, database *sql.DB) (int, string, error) {
+func getChannelFromAuthorization(r *http.Request, conn redis.Conn) (string, error) {
 	// backward compatible
 	sourceID := r.URL.Query().Get("source_id")
 	if sourceID != "" {
-		row := database.QueryRow(`
-SELECT s.id as source_id, c.uid
-FROM "sources" AS "s"
-INNER JOIN "channels" AS "c" ON s.channel_id = c.id
-WHERE "auth_code" = $1
-`, sourceID)
-
-		var channel string
-		var sourceID int
-		if err := row.Scan(&sourceID, &channel); err == sql.ErrNoRows {
-			return 0, "", errors.New("channel not found")
+		channel, err := redis.String(conn.Do("HGET", "sources", sourceID))
+		if err != nil {
+			return "", errors.Wrapf(err, "could not get channel for sourceID: %s", sourceID)
 		}
-		return sourceID, channel, nil
+		return channel, nil
 	}
 
 	// full micropub with indieauth
@@ -196,11 +186,41 @@ WHERE "auth_code" = $1
 		token := authHeader[7:]
 		channel, err := redis.String(conn.Do("HGET", "token:"+token, "channel"))
 		if err != nil {
-			return 0, "", errors.Wrap(err, "could not get channel for token")
+			_, err := verifyJWT(token)
+			if err != nil {
+				return "", err
+			}
+
+			return "", errors.Wrap(err, "could not get channel for token")
 		}
 
-		return 0, channel, nil
+		return channel, nil
 	}
 
-	return 0, "", fmt.Errorf("could not get channel from authorization")
+	return "", fmt.Errorf("could not get channel from authorization")
+}
+
+func verifyJWT(token string) (bool, error) {
+	key := []byte(os.Getenv("APP_SECRET"))
+	verifier, err := jwt.NewVerifierHS(jwt.HS256, key)
+	if err != nil {
+		return false, fmt.Errorf("could not create verifier for jwt: %w", err)
+	}
+
+	newToken, err := jwt.Parse([]byte(token), verifier)
+	if err != nil {
+		return false, fmt.Errorf("could not parse jwt: %w", err)
+	}
+
+	var newClaims jwt.RegisteredClaims
+	err = json.Unmarshal(newToken.Claims(), &newClaims)
+	if err != nil {
+		return false, fmt.Errorf("could not parse jwt claims: %w", err)
+	}
+
+	if !newClaims.IsValidAt(time.Now()) {
+		return false, fmt.Errorf("jwt is not valid now")
+	}
+
+	return true, nil
 }
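The verifyJWT helper above only checks the HS256 signature against APP_SECRET and the registered time claims. For context, a token that this check would accept can be minted with the same library; the sketch below is illustrative and not part of the branch, and the helper name, subject, and 24-hour expiry are assumptions.

package main

import (
	"fmt"
	"os"
	"time"

	"github.com/cristalhq/jwt/v4"
)

// mintToken is a hypothetical helper (not part of this change) that creates an
// HS256-signed token verifyJWT would accept, using the same APP_SECRET.
func mintToken() (string, error) {
	key := []byte(os.Getenv("APP_SECRET"))

	signer, err := jwt.NewSignerHS(jwt.HS256, key)
	if err != nil {
		return "", fmt.Errorf("could not create signer for jwt: %w", err)
	}

	// Illustrative claims: only the expiry matters to verifyJWT's IsValidAt check.
	claims := jwt.RegisteredClaims{
		Subject:   "micropub",
		ExpiresAt: jwt.NewNumericDate(time.Now().Add(24 * time.Hour)),
	}

	token, err := jwt.NewBuilder(signer).Build(claims)
	if err != nil {
		return "", fmt.Errorf("could not build jwt: %w", err)
	}
	return token.String(), nil
}

func main() {
	token, err := mintToken()
	if err != nil {
		fmt.Println(err)
		return
	}
	fmt.Println(token)
}

Such a token would then be presented to the micropub endpoint as "Authorization: Bearer <token>".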
go.mod (5 changed lines)

@@ -5,14 +5,15 @@ go 1.16
 require (
 	github.com/axgle/mahonia v0.0.0-20180208002826-3358181d7394
 	github.com/blevesearch/bleve/v2 v2.0.3
+	github.com/cristalhq/jwt/v4 v4.0.0-beta // indirect
 	github.com/davecgh/go-spew v1.1.1 // indirect
 	github.com/gilliek/go-opml v1.0.0
 	github.com/golang-migrate/migrate/v4 v4.15.1
-	github.com/gomodule/redigo v1.8.2
+	github.com/gomodule/redigo v1.8.5
 	github.com/lib/pq v1.10.1
 	github.com/pkg/errors v0.9.1
+	github.com/rafaeljusto/redigomock/v3 v3.0.1 // indirect
 	github.com/stretchr/testify v1.7.0
 	golang.org/x/net v0.0.0-20211013171255-e13a2654a71e
-	golang.org/x/text v0.3.7
 	willnorris.com/go/microformats v1.1.0
 )
go.sum (8 changed lines)

@@ -312,6 +312,8 @@ github.com/cpuguy83/go-md2man/v2 v2.0.0-20190314233015-f79a8a8ca69d/go.mod h1:ma
 github.com/cpuguy83/go-md2man/v2 v2.0.0/go.mod h1:maD7wRr/U5Z6m/iR4s+kqSMx2CaBsrgA7czyZG/E6dU=
 github.com/creack/pty v1.1.7/go.mod h1:lj5s0c3V2DBrqTV7llrYr5NG6My20zk30Fl46Y7DoTY=
 github.com/creack/pty v1.1.11/go.mod h1:oKZEueFk5CKHvIhNR5MUki03XCEU+Q6VDXinZuGJ33E=
+github.com/cristalhq/jwt/v4 v4.0.0-beta h1:TSNUSW7CCvMtX+Rg3vj706BwqIEt9HMvVv6Syifq5mU=
+github.com/cristalhq/jwt/v4 v4.0.0-beta/go.mod h1:HnYraSNKDRag1DZP92rYHyrjyQHnVEHPNqesmzs+miQ=
 github.com/cyphar/filepath-securejoin v0.2.2/go.mod h1:FpkQEhXnPnOthhzymB7CGsFk2G9VLXONKD9G7QGMM+4=
 github.com/cznic/mathutil v0.0.0-20180504122225-ca4c9f2c1369/go.mod h1:e6NPNENfs9mPDVNRekM7lKScauxd5kXTr1Mfyig6TDM=
 github.com/d2g/dhcp4 v0.0.0-20170904100407-a1d1b6c41b1c/go.mod h1:Ct2BUK8SB0YC1SMSibvLzxjeJLnrYEVLULFNiHY9YfQ=
@@ -496,6 +498,9 @@ github.com/golang/snappy v0.0.4 h1:yAGX7huGHXlcLOEtBnF4w7FQwA26wojNCwOYAEhLjQM=
 github.com/golang/snappy v0.0.4/go.mod h1:/XxbfmMg8lxefKM7IXC3fBNl/7bRcc72aCRzEWrmP2Q=
 github.com/gomodule/redigo v1.8.2 h1:H5XSIre1MB5NbPYFp+i1NBbb5qN1W8Y8YAQoAYbkm8k=
 github.com/gomodule/redigo v1.8.2/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUzlpkBYO0=
+github.com/gomodule/redigo v1.8.3/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUzlpkBYO0=
+github.com/gomodule/redigo v1.8.5 h1:nRAxCa+SVsyjSBrtZmG/cqb6VbTmuRzpg/PoTFlpumc=
+github.com/gomodule/redigo v1.8.5/go.mod h1:P9dn9mFrCBvWhGE1wpxx6fgq7BAeLBk+UUUzlpkBYO0=
 github.com/google/btree v0.0.0-20180813153112-4030bb1f1f0c/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
 github.com/google/btree v1.0.0/go.mod h1:lNA+9X1NB3Zf8V7Ke586lFgjr2dZNuvo3lPJSGZ5JPQ=
 github.com/google/flatbuffers v2.0.0+incompatible/go.mod h1:1AeVuKshWv4vARoZatz6mlQ0JxURH0Kv5+zNeJKJCa8=
@@ -838,6 +843,8 @@ github.com/prometheus/procfs v0.1.3/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4O
 github.com/prometheus/procfs v0.2.0/go.mod h1:lV6e/gmhEcM9IjHGsFOCxxuZ+z1YqCvr4OA4YeYWdaU=
 github.com/prometheus/procfs v0.6.0/go.mod h1:cz+aTbrPOrUb4q7XlbU9ygM+/jj0fzG6c1xBZuNvfVA=
 github.com/prometheus/tsdb v0.7.1/go.mod h1:qhTCs0VvXwvX/y3TZrWD7rabWM+ijKTux40TwIPHuXU=
+github.com/rafaeljusto/redigomock/v3 v3.0.1 h1:AUsXTuf+UEMwVEgRHRDYFFCJ1quS2JVDQmTWypjI5mI=
+github.com/rafaeljusto/redigomock/v3 v3.0.1/go.mod h1:51LNR7Q4YFsi0N+CHr7+FC1Jx2lPLzcRHCPlLO2Qbpw=
 github.com/rcrowley/go-metrics v0.0.0-20190826022208-cac0b30c2563/go.mod h1:bCqnVzQkZxMG4s8nGwiZ5l3QUCyqpo9Y+/ZMZ9VjZe4=
 github.com/remyoudompheng/bigfft v0.0.0-20190728182440-6a916e37a237/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
 github.com/remyoudompheng/bigfft v0.0.0-20200410134404-eec4a21b6bb0/go.mod h1:qqbHyh8v60DhA7CoWK5oRCqLrMHRGoxYCSS9EjAz6Eo=
@@ -1234,7 +1241,6 @@ golang.org/x/text v0.3.3/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.4/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.5/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
 golang.org/x/text v0.3.6/go.mod h1:5Zoc/QRtKVWzQhOtBMvqHzDpF6irO9z98xDceosuGiQ=
-golang.org/x/text v0.3.7 h1:olpwvP2KacW1ZWvsR7uQhoyTYvKAupfQrRGBFM352Gk=
 golang.org/x/text v0.3.7/go.mod h1:u+2+/6zg+i71rQMx5EYifcz6MCKuco9NR6JIITiCfzQ=
 golang.org/x/time v0.0.0-20180412165947-fbb02b2291d2/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
 golang.org/x/time v0.0.0-20181108054448-85acf8d2951c/go.mod h1:tRJNPiyCQ0inRvYxbN9jk5I+vvW/OXSQhTDSoE431IQ=
@@ -155,6 +155,7 @@ type Microsub interface {
 	ItemSearch(channel, query string) ([]Item, error)
 
 	Events() (chan sse.Message, error)
+	RefreshFeeds()
 }
 
 // MarshalJSON encodes an Unread value as JSON
@@ -5,10 +5,8 @@ import (
 	"encoding/xml"
 	"fmt"
 	"sort"
+	"strings"
 	"time"
-
-	"golang.org/x/text/cases"
-	"golang.org/x/text/language"
 )
 
 func parseRSS1(data []byte) (*Feed, error) {
@@ -31,8 +29,6 @@ func parseRSS1(data []byte) (*Feed, error) {
 	out.Description = channel.Description
 	out.Link = channel.Link
 	out.Image = channel.Image.Image()
-
-	titleCaser := cases.Title(language.English)
 	if channel.MinsToLive != 0 {
 		sort.Ints(channel.SkipHours)
 		next := time.Now().Add(time.Duration(channel.MinsToLive) * time.Minute)
@@ -45,7 +41,7 @@ func parseRSS1(data []byte) (*Feed, error) {
 		for trying {
 			trying = false
 			for _, day := range channel.SkipDays {
-				if titleCaser.String(day) == next.Weekday().String() {
+				if strings.Title(day) == next.Weekday().String() {
 					next.Add(time.Duration(24-next.Hour()) * time.Hour)
 					trying = true
 					break
@@ -5,10 +5,8 @@ import (
 	"encoding/xml"
 	"fmt"
 	"sort"
+	"strings"
 	"time"
-
-	"golang.org/x/text/cases"
-	"golang.org/x/text/language"
 )
 
 func parseRSS2(data []byte) (*Feed, error) {
@@ -40,7 +38,6 @@ func parseRSS2(data []byte) (*Feed, error) {
 
 	out.Image = channel.Image.Image()
 	if channel.MinsToLive != 0 {
-		titleCaser := cases.Title(language.English)
 		sort.Ints(channel.SkipHours)
 		next := time.Now().Add(time.Duration(channel.MinsToLive) * time.Minute)
 		for _, hour := range channel.SkipHours {
@@ -52,7 +49,7 @@ func parseRSS2(data []byte) (*Feed, error) {
 		for trying {
 			trying = false
 			for _, day := range channel.SkipDays {
-				if titleCaser.String(day) == next.Weekday().String() {
+				if strings.Title(day) == next.Weekday().String() {
 					next.Add(time.Duration(24-next.Hour()) * time.Hour)
 					trying = true
 					break
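Both RSS parsers revert the same call: the titleCaser built from golang.org/x/text/cases on master becomes strings.Title on this branch. strings.Title has been deprecated since Go 1.18 because its word-boundary rule does not handle Unicode punctuation properly, but for plain ASCII weekday names the two produce the same result. A minimal sketch of the two forms, with an assumed example value:

package main

import (
	"fmt"
	"strings"

	"golang.org/x/text/cases"
	"golang.org/x/text/language"
)

func main() {
	day := "monday" // example SkipDays value from a feed

	// Deprecated since Go 1.18, but adequate for single ASCII words like weekday names.
	fmt.Println(strings.Title(day)) // Monday

	// The replacement used on master: a language-aware title caser.
	titleCaser := cases.Title(language.English)
	fmt.Println(titleCaser.String(day)) // Monday
}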
@@ -239,7 +239,6 @@ func (h *microsubHandler) ServeHTTP(w http.ResponseWriter, r *http.Request) {
 			})
 			return
 		}
-		log.Printf("Searching for %s in %s (%d results)", query, channel, len(items))
 		respondJSON(w, map[string]interface{}{
 			"query": query,
 			"items": items,
@@ -27,6 +27,12 @@ import (
 type NullBackend struct {
 }
 
+// RefreshFeeds refreshes feeds
+func (b *NullBackend) RefreshFeeds() {
+	// TODO implement me
+	// panic("implement me")
+}
+
 // ChannelsGetList gets no channels
 func (b *NullBackend) ChannelsGetList() ([]microsub.Channel, error) {
 	return []microsub.Channel{
@@ -53,11 +53,52 @@ func (p *postgresStream) Init() error {
 	if err != nil {
 		return fmt.Errorf("database ping failed: %w", err)
 	}
+	//
+	// _, err = conn.ExecContext(ctx, `
+	// CREATE TABLE IF NOT EXISTS "channels" (
+	//     "id" int primary key generated always as identity,
+	//     "name" varchar(255) unique,
+	//     "created_at" timestamp DEFAULT current_timestamp
+	// );
+	// `)
+	// if err != nil {
+	// 	return fmt.Errorf("create channels table failed: %w", err)
+	// }
+	//
+	// _, err = conn.ExecContext(ctx, `
+	// CREATE TABLE IF NOT EXISTS "items" (
+	//     "id" int primary key generated always as identity,
+	//     "channel_id" int references "channels" on delete cascade,
+	//     "uid" varchar(512) not null unique,
+	//     "is_read" int default 0,
+	//     "data" jsonb,
+	//     "created_at" timestamp DEFAULT current_timestamp,
+	//     "updated_at" timestamp,
+	//     "published_at" timestamp
+	// );
+	// `)
+	// if err != nil {
+	// 	return fmt.Errorf("create items table failed: %w", err)
+	// }
+	//
+	// _, err = conn.ExecContext(ctx, `ALTER TABLE "items" ALTER COLUMN "data" TYPE jsonb, ALTER COLUMN "uid" TYPE varchar(1024)`)
+	// if err != nil {
+	// 	return fmt.Errorf("alter items table failed: %w", err)
+	// }
 
+	_, err = conn.ExecContext(ctx, `INSERT INTO "channels" ("uid", "name", "created_at") VALUES ($1, $1, DEFAULT)
+		ON CONFLICT DO NOTHING`, p.channel)
+	if err != nil {
+		return fmt.Errorf("create channel failed: %w", err)
+	}
+
 	row := conn.QueryRowContext(ctx, `SELECT "id" FROM "channels" WHERE "uid" = $1`, p.channel)
+	if row == nil {
+		return fmt.Errorf("fetch channel failed: %w", err)
+	}
 	err = row.Scan(&p.channelID)
-	if err == sql.ErrNoRows {
-		return fmt.Errorf("channel %s not found: %w", p.channel, err)
+	if err != nil {
+		return fmt.Errorf("fetch channel failed while scanning: %w", err)
 	}
 
 	return nil
@@ -86,16 +127,16 @@ WHERE "channel_id" = $1
 			log.Println(err)
 		} else {
 			args = append(args, b)
-			qb.WriteString(` AND "published_at" > $2`)
+			qb.WriteString(` AND "published_at" < $2`)
 		}
 	} else if after != "" {
 		b, err := time.Parse(time.RFC3339, after)
 		if err == nil {
 			args = append(args, b)
-			qb.WriteString(` AND "published_at" < $2`)
+			qb.WriteString(` AND "published_at" > $2`)
 		}
 	}
-	qb.WriteString(` ORDER BY "published_at" DESC LIMIT 20`)
+	qb.WriteString(` ORDER BY "published_at" DESC LIMIT 10`)
 
 	rows, err := conn.QueryContext(context.Background(), qb.String(), args...)
 	if err != nil {
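Note that the two sides disagree on both the cursor comparison direction and the page size (LIMIT 20 on master versus LIMIT 10 on this branch). As a rough sketch of how the master-side query is assembled for a "before" cursor: the function name, base SELECT, and its column list below are assumptions; only the cursor comparison, ordering, and LIMIT handling come from the hunk above.

package main

import (
	"fmt"
	"strings"
	"time"
)

// buildTimelineQuery is a hypothetical standalone rendering of the master-side
// query assembly for the items timeline.
func buildTimelineQuery(channelID int, before string) (string, []interface{}) {
	var qb strings.Builder
	qb.WriteString(`SELECT "uid", "data", "published_at" FROM "items" WHERE "channel_id" = $1`)
	args := []interface{}{channelID}

	if before != "" {
		if t, err := time.Parse(time.RFC3339, before); err == nil {
			args = append(args, t)
			qb.WriteString(` AND "published_at" > $2`) // this branch flips the comparison to "<"
		}
	}
	qb.WriteString(` ORDER BY "published_at" DESC LIMIT 20`) // and lowers the limit to 10

	return qb.String(), args
}

func main() {
	query, args := buildTimelineQuery(1, "2022-01-01T00:00:00Z")
	fmt.Println(query, args)
}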
@@ -139,12 +180,9 @@ WHERE "channel_id" = $1
 		return tl, err
 	}
 
-	if len(tl.Items) > 0 && hasMoreBefore(conn, tl.Items[0].Published) {
-		tl.Paging.Before = tl.Items[0].Published
-	}
-	if hasMoreAfter(conn, last) {
-		tl.Paging.After = last
-	}
+	// TODO: should only be set of there are more items available
+	tl.Paging.Before = last
+	// tl.Paging.After = last
 
 	if tl.Items == nil {
 		tl.Items = []microsub.Item{}
@@ -153,24 +191,6 @@ WHERE "channel_id" = $1
 	return tl, nil
 }
 
-func hasMoreBefore(conn *sql.Conn, before string) bool {
-	row := conn.QueryRowContext(context.Background(), `SELECT COUNT(*) FROM "items" WHERE "published_at" > $1`, before)
-	var count int
-	if err := row.Scan(&count); err == sql.ErrNoRows {
-		return false
-	}
-	return count > 0
-}
-
-func hasMoreAfter(conn *sql.Conn, after string) bool {
-	row := conn.QueryRowContext(context.Background(), `SELECT COUNT(*) FROM "items" WHERE "published_at" < $1`, after)
-	var count int
-	if err := row.Scan(&count); err == sql.ErrNoRows {
-		return false
-	}
-	return count > 0
-}
-
 // Count
 func (p *postgresStream) Count() (int, error) {
 	ctx := context.Background()
@@ -179,12 +199,16 @@ func (p *postgresStream) Count() (int, error) {
 		return -1, err
 	}
 	defer conn.Close()
-	var count int
 	row := conn.QueryRowContext(context.Background(), `SELECT COUNT(*) FROM items WHERE channel_id = $1 AND "is_read" = 0`, p.channelID)
-	err = row.Scan(&count)
-	if err != nil && err == sql.ErrNoRows {
+	if row == nil {
 		return 0, nil
 	}
+	var count int
+	err = row.Scan(&count)
+	if err != nil {
+		return -1, err
+	}
 
 	return count, nil
 }