Skip to content

Commit

Permalink
Random tiny changes
Browse files Browse the repository at this point in the history
  • Loading branch information
binwiederhier committed May 30, 2023
1 parent 0d2f0dd commit ad761f4
Show file tree
Hide file tree
Showing 7 changed files with 32 additions and 63 deletions.
2 changes: 1 addition & 1 deletion go.mod
Original file line number Diff line number Diff line change
Expand Up @@ -27,6 +27,7 @@ require github.com/pkg/errors v0.9.1 // indirect

require (
firebase.google.com/go/v4 v4.11.0
github.com/SherClockHolmes/webpush-go v1.2.0
github.com/prometheus/client_golang v1.15.1
github.com/stripe/stripe-go/v74 v74.18.0
)
Expand All @@ -39,7 +40,6 @@ require (
cloud.google.com/go/longrunning v0.4.2 // indirect
github.com/AlekSi/pointer v1.2.0 // indirect
github.com/MicahParks/keyfunc v1.9.0 // indirect
github.com/SherClockHolmes/webpush-go v1.2.0 // indirect
github.com/beorn7/perks v1.0.1 // indirect
github.com/cespare/xxhash/v2 v2.2.0 // indirect
github.com/davecgh/go-spew v1.1.1 // indirect
Expand Down
4 changes: 1 addition & 3 deletions server/server_account.go
Original file line number Diff line number Diff line change
Expand Up @@ -171,9 +171,7 @@ func (s *Server) handleAccountDelete(w http.ResponseWriter, r *http.Request, v *
return errHTTPBadRequestIncorrectPasswordConfirmation
}
if s.webPush != nil {
err := s.webPush.ExpireWebPushForUser(u.Name)

if err != nil {
if err := s.webPush.RemoveByUserID(u.ID); err != nil {
logvr(v, r).Err(err).Warn("Error removing web push subscriptions for %s", u.Name)
}
}
Expand Down
10 changes: 2 additions & 8 deletions server/server_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -2598,12 +2598,8 @@ func newTestConfigWithAuthFile(t *testing.T) *Config {

func newTestConfigWithWebPush(t *testing.T) *Config {
conf := newTestConfig(t)

privateKey, publicKey, err := webpush.GenerateVAPIDKeys()
if err != nil {
t.Fatal(err)
}

require.Nil(t, err)
conf.WebPushEnabled = true
conf.WebPushSubscriptionsFile = filepath.Join(t.TempDir(), "subscriptions.db")
conf.WebPushEmailAddress = "[email protected]"
Expand All @@ -2614,9 +2610,7 @@ func newTestConfigWithWebPush(t *testing.T) *Config {

func newTestServer(t *testing.T, config *Config) *Server {
server, err := New(config)
if err != nil {
t.Fatal(err)
}
require.Nil(t, err)
return server
}

Expand Down
34 changes: 7 additions & 27 deletions server/server_web_push.go
Original file line number Diff line number Diff line change
Expand Up @@ -10,15 +10,8 @@ import (
)

func (s *Server) handleTopicWebPushSubscribe(w http.ResponseWriter, r *http.Request, v *visitor) error {
var username string
u := v.User()
if u != nil {
username = u.Name
}

var sub webPushSubscribePayload
err := json.NewDecoder(r.Body).Decode(&sub)

if err != nil || sub.BrowserSubscription.Endpoint == "" || sub.BrowserSubscription.Keys.P256dh == "" || sub.BrowserSubscription.Keys.Auth == "" {
return errHTTPBadRequestWebPushSubscriptionInvalid
}
Expand All @@ -27,12 +20,9 @@ func (s *Server) handleTopicWebPushSubscribe(w http.ResponseWriter, r *http.Requ
if err != nil {
return err
}

err = s.webPush.AddSubscription(topic.ID, username, sub)
if err != nil {
if err = s.webPush.AddSubscription(topic.ID, v.MaybeUserID(), sub); err != nil {
return err
}

return s.writeJSON(w, newSuccessResponse())
}

Expand All @@ -59,7 +49,7 @@ func (s *Server) handleTopicWebPushUnsubscribe(w http.ResponseWriter, r *http.Re
}

func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
subscriptions, err := s.webPush.GetSubscriptionsForTopic(m.Topic)
subscriptions, err := s.webPush.SubscriptionsForTopic(m.Topic)
if err != nil {
logvm(v, m).Err(err).Warn("Unable to publish web push messages")
return
Expand All @@ -69,21 +59,17 @@ func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {

// Importing the emojis in the service worker would add unnecessary complexity,
// simply do it here for web push notifications instead
var titleWithDefault string
var formattedTitle string

var titleWithDefault, formattedTitle string
emojis, _, err := toEmojis(m.Tags)
if err != nil {
logvm(v, m).Err(err).Fields(ctx).Debug("Unable to publish web push message")
return
}

if m.Title == "" {
titleWithDefault = m.Topic
} else {
titleWithDefault = m.Title
}

if len(emojis) > 0 {
formattedTitle = fmt.Sprintf("%s %s", strings.Join(emojis[:], " "), titleWithDefault)
} else {
Expand All @@ -92,7 +78,7 @@ func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {

for i, xi := range subscriptions {
go func(i int, sub webPushSubscription) {
ctx := log.Context{"endpoint": sub.BrowserSubscription.Endpoint, "username": sub.Username, "topic": m.Topic, "message_id": m.ID}
ctx := log.Context{"endpoint": sub.BrowserSubscription.Endpoint, "username": sub.UserID, "topic": m.Topic, "message_id": m.ID}

payload := &webPushPayload{
SubscriptionID: fmt.Sprintf("%s/%s", s.config.BaseURL, m.Topic),
Expand All @@ -110,31 +96,25 @@ func (s *Server) publishToWebPushEndpoints(v *visitor, m *message) {
Subscriber: s.config.WebPushEmailAddress,
VAPIDPublicKey: s.config.WebPushPublicKey,
VAPIDPrivateKey: s.config.WebPushPrivateKey,
// deliverability on iOS isn't great with lower urgency values,
// Deliverability on iOS isn't great with lower urgency values,
// and thus we can't really map lower ntfy priorities to lower urgency values
Urgency: webpush.UrgencyHigh,
})

if err != nil {
logvm(v, m).Err(err).Fields(ctx).Debug("Unable to publish web push message")

err = s.webPush.ExpireWebPushEndpoint(sub.BrowserSubscription.Endpoint)
if err != nil {
if err := s.webPush.RemoveByEndpoint(sub.BrowserSubscription.Endpoint); err != nil {
logvm(v, m).Err(err).Fields(ctx).Warn("Unable to expire subscription")
}

return
}

// May want to handle at least 429 differently, but for now treat all errors the same
if !(200 <= resp.StatusCode && resp.StatusCode <= 299) {
logvm(v, m).Fields(ctx).Field("response", resp).Debug("Unable to publish web push message")

err = s.webPush.ExpireWebPushEndpoint(sub.BrowserSubscription.Endpoint)
if err != nil {
if err := s.webPush.RemoveByEndpoint(sub.BrowserSubscription.Endpoint); err != nil {
logvm(v, m).Err(err).Fields(ctx).Warn("Unable to expire subscription")
}

return
}
}(i, xi)
Expand Down
17 changes: 7 additions & 10 deletions server/server_web_push_test.go
Original file line number Diff line number Diff line change
Expand Up @@ -5,6 +5,7 @@ import (
"io"
"net/http"
"net/http/httptest"
"strings"
"sync/atomic"
"testing"

Expand Down Expand Up @@ -41,7 +42,7 @@ func TestServer_WebPush_TopicSubscribe(t *testing.T) {
require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())

subs, err := s.webPush.GetSubscriptionsForTopic("test-topic")
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
if err != nil {
t.Fatal(err)
}
Expand All @@ -50,7 +51,7 @@ func TestServer_WebPush_TopicSubscribe(t *testing.T) {
require.Equal(t, subs[0].BrowserSubscription.Endpoint, "https://example.com/webpush")
require.Equal(t, subs[0].BrowserSubscription.Keys.P256dh, "p256dh-key")
require.Equal(t, subs[0].BrowserSubscription.Keys.Auth, "auth-key")
require.Equal(t, subs[0].Username, "")
require.Equal(t, subs[0].UserID, "")
}

func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
Expand All @@ -64,17 +65,13 @@ func TestServer_WebPush_TopicSubscribeProtected_Allowed(t *testing.T) {
response := request(t, s, "POST", "/test-topic/web-push/subscribe", webPushSubscribePayloadExample, map[string]string{
"Authorization": util.BasicAuth("ben", "ben"),
})

require.Equal(t, 200, response.Code)
require.Equal(t, `{"success":true}`+"\n", response.Body.String())

subs, err := s.webPush.GetSubscriptionsForTopic("test-topic")
if err != nil {
t.Fatal(err)
}

subs, err := s.webPush.SubscriptionsForTopic("test-topic")
require.Nil(t, err)
require.Len(t, subs, 1)
require.Equal(t, subs[0].Username, "ben")
require.True(t, strings.HasPrefix(subs[0].UserID, "u_"))
}

func TestServer_WebPush_TopicSubscribeProtected_Denied(t *testing.T) {
Expand Down Expand Up @@ -203,7 +200,7 @@ func addSubscription(t *testing.T, s *Server, topic string, url string) {
}

func requireSubscriptionCount(t *testing.T, s *Server, topic string, expectedLength int) {
subs, err := s.webPush.GetSubscriptionsForTopic("test-topic")
subs, err := s.webPush.SubscriptionsForTopic("test-topic")
if err != nil {
t.Fatal(err)
}
Expand Down
4 changes: 2 additions & 2 deletions server/types.go
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,7 @@ type message struct {
PollID string `json:"poll_id,omitempty"`
Encoding string `json:"encoding,omitempty"` // empty for raw UTF-8, or "base64" for encoded bytes
Sender netip.Addr `json:"-"` // IP address of uploader, used for rate limiting
User string `json:"-"` // Username of the uploader, used to associated attachments
User string `json:"-"` // UserID of the uploader, used to associated attachments
}

func (m *message) Context() log.Context {
Expand Down Expand Up @@ -476,7 +476,7 @@ type webPushPayload struct {

type webPushSubscription struct {
BrowserSubscription webpush.Subscription
Username string
UserID string
}

type webPushSubscribePayload struct {
Expand Down
24 changes: 12 additions & 12 deletions server/web_push.go
Original file line number Diff line number Diff line change
Expand Up @@ -12,7 +12,7 @@ const (
CREATE TABLE IF NOT EXISTS subscriptions (
id INTEGER PRIMARY KEY AUTOINCREMENT,
topic TEXT NOT NULL,
username TEXT,
user_id TEXT,
endpoint TEXT NOT NULL,
key_auth TEXT NOT NULL,
key_p256dh TEXT NOT NULL,
Expand All @@ -24,14 +24,14 @@ const (
COMMIT;
`
insertWebPushSubscriptionQuery = `
INSERT OR REPLACE INTO subscriptions (topic, username, endpoint, key_auth, key_p256dh)
INSERT OR REPLACE INTO subscriptions (topic, user_id, endpoint, key_auth, key_p256dh)
VALUES (?, ?, ?, ?, ?)
`
deleteWebPushSubscriptionByEndpointQuery = `DELETE FROM subscriptions WHERE endpoint = ?`
deleteWebPushSubscriptionByUsernameQuery = `DELETE FROM subscriptions WHERE username = ?`
deleteWebPushSubscriptionByUserIDQuery = `DELETE FROM subscriptions WHERE user_id = ?`
deleteWebPushSubscriptionByTopicAndEndpointQuery = `DELETE FROM subscriptions WHERE topic = ? AND endpoint = ?`

selectWebPushSubscriptionsForTopicQuery = `SELECT endpoint, key_auth, key_p256dh, username FROM subscriptions WHERE topic = ?`
selectWebPushSubscriptionsForTopicQuery = `SELECT endpoint, key_auth, key_p256dh, user_id FROM subscriptions WHERE topic = ?`

selectWebPushSubscriptionsCountQuery = `SELECT COUNT(*) FROM subscriptions`
)
Expand Down Expand Up @@ -69,11 +69,11 @@ func setupNewSubscriptionsDB(db *sql.DB) error {
return nil
}

func (c *webPushStore) AddSubscription(topic string, username string, subscription webPushSubscribePayload) error {
func (c *webPushStore) AddSubscription(topic string, userID string, subscription webPushSubscribePayload) error {
_, err := c.db.Exec(
insertWebPushSubscriptionQuery,
topic,
username,
userID,
subscription.BrowserSubscription.Endpoint,
subscription.BrowserSubscription.Keys.Auth,
subscription.BrowserSubscription.Keys.P256dh,
Expand All @@ -90,7 +90,7 @@ func (c *webPushStore) RemoveSubscription(topic string, endpoint string) error {
return err
}

func (c *webPushStore) GetSubscriptionsForTopic(topic string) (subscriptions []webPushSubscription, err error) {
func (c *webPushStore) SubscriptionsForTopic(topic string) (subscriptions []webPushSubscription, err error) {
rows, err := c.db.Query(selectWebPushSubscriptionsForTopicQuery, topic)
if err != nil {
return nil, err
Expand All @@ -100,7 +100,7 @@ func (c *webPushStore) GetSubscriptionsForTopic(topic string) (subscriptions []w
var data []webPushSubscription
for rows.Next() {
i := webPushSubscription{}
err = rows.Scan(&i.BrowserSubscription.Endpoint, &i.BrowserSubscription.Keys.Auth, &i.BrowserSubscription.Keys.P256dh, &i.Username)
err = rows.Scan(&i.BrowserSubscription.Endpoint, &i.BrowserSubscription.Keys.Auth, &i.BrowserSubscription.Keys.P256dh, &i.UserID)
if err != nil {
return nil, err
}
Expand All @@ -109,18 +109,18 @@ func (c *webPushStore) GetSubscriptionsForTopic(topic string) (subscriptions []w
return data, nil
}

func (c *webPushStore) ExpireWebPushEndpoint(endpoint string) error {
func (c *webPushStore) RemoveByEndpoint(endpoint string) error {
_, err := c.db.Exec(
deleteWebPushSubscriptionByEndpointQuery,
endpoint,
)
return err
}

func (c *webPushStore) ExpireWebPushForUser(username string) error {
func (c *webPushStore) RemoveByUserID(userID string) error {
_, err := c.db.Exec(
deleteWebPushSubscriptionByUsernameQuery,
username,
deleteWebPushSubscriptionByUserIDQuery,
userID,
)
return err
}
Expand Down

0 comments on commit ad761f4

Please sign in to comment.