diff --git a/pkg/timeline/postgres.go b/pkg/timeline/postgres.go index 2dc0c9b..7ac828d 100644 --- a/pkg/timeline/postgres.go +++ b/pkg/timeline/postgres.go @@ -173,18 +173,14 @@ func (p *postgresStream) Count() (int, error) { return -1, err } defer conn.Close() - rows, err := conn.QueryContext(context.Background(), "SELECT COUNT(*) FROM items WHERE channel_id = ?", p.channel) - if err != nil { - return 0, err + row := conn.QueryRowContext(context.Background(), "SELECT COUNT(*) FROM items WHERE channel_id = $1", p.channelID) + if row == nil { + return 0, nil } - var count int - for rows.Next() { - err = rows.Scan(&count) - if err != nil { - return -1, err - } - break + err = row.Scan(&count) + if err != nil { + return -1, err } return count, nil @@ -224,10 +220,10 @@ func (p *postgresStream) MarkRead(uids []string) error { ctx := context.Background() conn, err := p.database.Conn(ctx) if err != nil { - return err + return fmt.Errorf("getting connection: %w", err) } defer conn.Close() - _, err = conn.ExecContext(context.Background(), `UPDATE "items" SET is_read = 1 WHERE "uid" IN ($1)`, pq.Array(uids)) + _, err = conn.ExecContext(context.Background(), `UPDATE "items" SET is_read = 1 WHERE "uid" = ANY($1)`, pq.Array(uids)) if err != nil { return fmt.Errorf("while marking as read: %w", err) }