diff --git a/coderd/inboxnotifications.go b/coderd/inboxnotifications.go index 4bb3f9ec953aa..c26f435dd2b0b 100644 --- a/coderd/inboxnotifications.go +++ b/coderd/inboxnotifications.go @@ -21,7 +21,6 @@ import ( "github.com/coder/coder/v2/coderd/pubsub" markdown "github.com/coder/coder/v2/coderd/render" "github.com/coder/coder/v2/codersdk" - "github.com/coder/coder/v2/codersdk/wsjson" "github.com/coder/websocket" ) @@ -127,6 +126,7 @@ func (api *API) watchInboxNotifications(rw http.ResponseWriter, r *http.Request) templates = p.UUIDs(vals, []uuid.UUID{}, "templates") readStatus = p.String(vals, "all", "read_status") format = p.String(vals, notificationFormatMarkdown, "format") + logger = api.Logger.Named("inbox_notifications_watcher") ) p.ErrorExcessParams(vals) if len(p.Errors) > 0 { @@ -214,11 +214,17 @@ func (api *API) watchInboxNotifications(rw http.ResponseWriter, r *http.Request) return } - go httpapi.Heartbeat(ctx, conn) - defer conn.Close(websocket.StatusNormalClosure, "connection closed") + ctx, cancel := context.WithCancel(ctx) + defer cancel() - encoder := wsjson.NewEncoder[codersdk.GetInboxNotificationResponse](conn, websocket.MessageText) - defer encoder.Close(websocket.StatusNormalClosure) + _ = conn.CloseRead(context.Background()) + + ctx, wsNetConn := codersdk.WebsocketNetConn(ctx, conn, websocket.MessageText) + defer wsNetConn.Close() + + go httpapi.HeartbeatClose(ctx, logger, cancel, conn) + + encoder := json.NewEncoder(wsNetConn) // Log the request immediately instead of after it completes. if rl := loggermw.RequestLoggerFromContext(ctx); rl != nil { @@ -227,8 +233,12 @@ func (api *API) watchInboxNotifications(rw http.ResponseWriter, r *http.Request) for { select { + case <-api.ctx.Done(): + return + case <-ctx.Done(): return + case notif := <-notificationCh: unreadCount, err := api.Database.CountUnreadInboxNotificationsByUserID(ctx, apikey.UserID) if err != nil {