diff --git a/pkg/timeline/postgres.go b/pkg/timeline/postgres.go index 46fa137..3168c38 100644 --- a/pkg/timeline/postgres.go +++ b/pkg/timeline/postgres.go @@ -1,8 +1,10 @@ package timeline import ( + "context" "database/sql" "fmt" + "log" "strings" "time" @@ -20,13 +22,18 @@ type postgresStream struct { // Init func (p *postgresStream) Init() error { - db := p.database - err := db.Ping() + ctx := context.Background() + conn, err := p.database.Conn(ctx) + if err != nil { + return err + } + defer conn.Close() + err = conn.PingContext(ctx) if err != nil { return fmt.Errorf("database ping failed: %w", err) } - _, err = db.Exec(` + _, err = conn.ExecContext(ctx, ` CREATE TABLE IF NOT EXISTS "channels" ( "id" int primary key generated always as identity, "name" varchar(255) unique, @@ -37,7 +44,7 @@ CREATE TABLE IF NOT EXISTS "channels" ( return fmt.Errorf("create channels table failed: %w", err) } - _, err = db.Exec(` + _, err = conn.ExecContext(ctx, ` CREATE TABLE IF NOT EXISTS "items" ( "id" int primary key generated always as identity, "channel_id" int references "channels" on delete cascade, @@ -53,22 +60,19 @@ CREATE TABLE IF NOT EXISTS "items" ( return fmt.Errorf("create items table failed: %w", err) } - _, err = db.Exec(`INSERT INTO "channels" ("name", "created_at") VALUES ($1, DEFAULT) + _, err = conn.ExecContext(ctx, `INSERT INTO "channels" ("name", "created_at") VALUES ($1, DEFAULT) ON CONFLICT DO NOTHING`, p.channel) if err != nil { return fmt.Errorf("create channel failed: %w", err) } - rows, err := db.Query(`SELECT "id" FROM "channels" WHERE "name" = $1`, p.channel) - if err != nil { + row := conn.QueryRowContext(ctx, `SELECT "id" FROM "channels" WHERE "name" = $1`, p.channel) + if row == nil { return fmt.Errorf("fetch channel failed: %w", err) } - for rows.Next() { - err = rows.Scan(&p.channelID) - if err != nil { - return fmt.Errorf("fetch channel failed while scanning: %w", err) - } - break + err = row.Scan(&p.channelID) + if err != nil { + return fmt.Errorf("fetch channel failed while scanning: %w", err) } return nil @@ -76,6 +80,13 @@ CREATE TABLE IF NOT EXISTS "items" ( // Items func (p *postgresStream) Items(before, after string) (microsub.Timeline, error) { + ctx := context.Background() + conn, err := p.database.Conn(ctx) + if err != nil { + return microsub.Timeline{}, err + } + defer conn.Close() + var args []interface{} args = append(args, p.channelID) var qb strings.Builder @@ -86,7 +97,9 @@ WHERE "channel_id" = $1 `) if before != "" { b, err := time.Parse(time.RFC3339, before) - if err == nil { + if err != nil { + log.Println(err) + } else { args = append(args, b) qb.WriteString(` AND "published_at" < $2`) } @@ -97,9 +110,9 @@ WHERE "channel_id" = $1 qb.WriteString(` AND "published_at" > $2`) } } - qb.WriteString(` ORDER BY "published_at"`) + qb.WriteString(` ORDER BY "published_at" DESC LIMIT 10`) - rows, err := p.database.Query(qb.String(), args...) + rows, err := conn.QueryContext(context.Background(), qb.String(), args...) if err != nil { return microsub.Timeline{}, fmt.Errorf("while query: %w", err) } @@ -127,6 +140,7 @@ WHERE "channel_id" = $1 item.Read = isRead == 1 item.ID = uid + item.Published = publishedAt tl.Items = append(tl.Items, item) } @@ -140,15 +154,22 @@ WHERE "channel_id" = $1 return tl, err } - tl.Paging.Before = first - tl.Paging.After = last + // TODO: should only be set of there are more items available + tl.Paging.Before = last + // tl.Paging.After = last return tl, nil } // Count func (p *postgresStream) Count() (int, error) { - rows, err := p.database.Query("SELECT COUNT(*) FROM items WHERE channel_id = ?", p.channel) + ctx := context.Background() + conn, err := p.database.Conn(ctx) + if err != nil { + return -1, err + } + defer conn.Close() + rows, err := conn.QueryContext(context.Background(), "SELECT COUNT(*) FROM items WHERE channel_id = ?", p.channel) if err != nil { return 0, err } @@ -167,6 +188,13 @@ func (p *postgresStream) Count() (int, error) { // AddItem func (p *postgresStream) AddItem(item microsub.Item) (bool, error) { + ctx := context.Background() + conn, err := p.database.Conn(ctx) + if err != nil { + return false, err + } + defer conn.Close() + t, err := time.Parse("2006-01-02T15:04:05Z0700", item.Published) if err != nil { t2, err := time.Parse("2006-01-02T15:04:05Z07:00", item.Published) @@ -176,7 +204,7 @@ func (p *postgresStream) AddItem(item microsub.Item) (bool, error) { t = t2 } - _, err = p.database.Exec(` + _, err = conn.ExecContext(context.Background(), ` INSERT INTO "items" ("channel_id", "uid", "data", "published_at", "created_at") VALUES ($1, $2, $3, $4, DEFAULT) ON CONFLICT ON CONSTRAINT "items_uid_key" DO UPDATE SET "updated_at" = now() @@ -189,7 +217,13 @@ ON CONFLICT ON CONSTRAINT "items_uid_key" DO UPDATE SET "updated_at" = now() // MarkRead func (p *postgresStream) MarkRead(uids []string) error { - _, err := p.database.Exec(`UPDATE "items" SET is_read = 1 WHERE "uid" IN ($1)`, uids) + ctx := context.Background() + conn, err := p.database.Conn(ctx) + if err != nil { + return err + } + defer conn.Close() + _, err = conn.ExecContext(context.Background(), `UPDATE "items" SET is_read = 1 WHERE "uid" IN ($1)`, uids) if err != nil { return fmt.Errorf("while marking as read: %w", err) }