diff --git a/go.mod b/go.mod index 6d41af507d351..93149c40d1bd9 100644 --- a/go.mod +++ b/go.mod @@ -51,6 +51,7 @@ require ( github.com/google/uuid v1.3.0 github.com/gorilla/feeds v1.1.1 github.com/gorilla/sessions v1.2.1 + github.com/gorilla/websocket v1.4.2 github.com/hashicorp/go-version v1.4.0 github.com/hashicorp/golang-lru v0.5.4 github.com/huandu/xstrings v1.3.2 @@ -194,7 +195,6 @@ require ( github.com/gorilla/handlers v1.5.1 // indirect github.com/gorilla/mux v1.8.0 // indirect github.com/gorilla/securecookie v1.1.1 // indirect - github.com/gorilla/websocket v1.4.2 // indirect github.com/grpc-ecosystem/go-grpc-middleware v1.3.0 // indirect github.com/grpc-ecosystem/go-grpc-prometheus v1.2.0 // indirect github.com/grpc-ecosystem/grpc-gateway v1.16.0 // indirect diff --git a/modules/context/response.go b/modules/context/response.go index 112964dbe14cd..0a844f9d4ae57 100644 --- a/modules/context/response.go +++ b/modules/context/response.go @@ -84,14 +84,31 @@ func (r *Response) Before(f func(ResponseWriter)) { r.befores = append(r.befores, f) } +// hijackerResponse wraps the Response to allow casting as a Hijacker if the underlying response is a hijacker +type hijackerResponse struct { + *Response + http.Hijacker +} + // NewResponse creates a response -func NewResponse(resp http.ResponseWriter) *Response { +func NewResponse(resp http.ResponseWriter) ResponseWriter { if v, ok := resp.(*Response); ok { return v } - return &Response{ + hijacker, ok := resp.(http.Hijacker) + + response := &Response{ ResponseWriter: resp, status: 0, befores: make([]func(ResponseWriter), 0), } + if ok { + // ensure that the Response we return is also hijackable + return hijackerResponse{ + Response: response, + Hijacker: hijacker, + } + } + + return response } diff --git a/routers/web/events/websocket.go b/routers/web/events/websocket.go new file mode 100644 index 0000000000000..7acdcf1df187c --- /dev/null +++ b/routers/web/events/websocket.go @@ -0,0 +1,235 @@ +// Copyright 2022 The Gitea Authors. All rights reserved. +// Use of this source code is governed by a MIT-style +// license that can be found in the LICENSE file. + +package events + +import ( + "net/http" + "net/url" + "time" + + "code.gitea.io/gitea/modules/context" + "code.gitea.io/gitea/modules/eventsource" + "code.gitea.io/gitea/modules/graceful" + "code.gitea.io/gitea/modules/json" + "code.gitea.io/gitea/modules/log" + "code.gitea.io/gitea/modules/setting" + "code.gitea.io/gitea/routers/web/auth" + "github.com/gorilla/websocket" +) + +const ( + writeWait = 10 * time.Second + pongWait = 60 * time.Second + pingPeriod = (pongWait * 9) / 10 + + maximumMessageSize = 2048 + readMessageChanSize = 20 // <- I've put 20 here because it seems like a reasonable buffer but it may to increase +) + +type readMessage struct { + messageType int + message []byte + err error +} + +// Events listens for events +func Websocket(ctx *context.Context) { + upgrader := websocket.Upgrader{ + ReadBufferSize: 1024, + WriteBufferSize: 1024, + CheckOrigin: func(r *http.Request) bool { + origin := r.Header["Origin"] + if len(origin) == 0 { + return true + } + u, err := url.Parse(origin[0]) + if err != nil { + return false + } + appURLURL, err := url.Parse(setting.AppURL) + if err != nil { + return true + } + + return u.Host == appURLURL.Host + }, + } + + // Because http proxies will tend not to pass these headers + ctx.Req.Header.Add("Upgrade", "websocket") + ctx.Req.Header.Add("Connection", "Upgrade") + + conn, err := upgrader.Upgrade(ctx.Resp, ctx.Req, nil) + if err != nil { + log.Error("Unable to upgrade due to error: %v", err) + return + } + defer conn.Close() + + notify := ctx.Done() + shutdownCtx := graceful.GetManager().ShutdownContext() + + eventChan := make(<-chan *eventsource.Event) + uid := int64(0) + unregister := func() {} + if ctx.IsSigned { + uid = ctx.Doer.ID + eventChan = eventsource.GetManager().Register(uid) + unregister = func() { + go func() { + eventsource.GetManager().Unregister(uid, eventChan) + // ensure the messageChan is closed + for { + _, ok := <-eventChan + if !ok { + break + } + } + }() + } + } + defer unregister() + + readMessageChan := make(chan readMessage, readMessageChanSize) + go readMessagesFromConnToChan(conn, readMessageChan) + + pingTicker := time.NewTicker(pingPeriod) + + for { + select { + case <-notify: + return + case <-shutdownCtx.Done(): + return + case <-pingTicker.C: + // ensure that we're not already cancelled + select { + case <-notify: + return + case <-shutdownCtx.Done(): + return + default: + } + + if err := conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { + log.Error("unable to SetWriteDeadline: %v", err) + return + } + if err := conn.WriteMessage(websocket.PingMessage, []byte{}); err != nil { + log.Error("unable to send PingMessage: %v", err) + return + } + case message, ok := <-readMessageChan: + if !ok { + break + } + + // ensure that we're not already cancelled + select { + case <-notify: + return + case <-shutdownCtx.Done(): + return + default: + } + + // FIXME: HANDLE MESSAGES + log.Info("Got Message: %d:%s:%v", message.messageType, message.message, message.err) + case event, ok := <-eventChan: + if !ok { + break + } + + // ensure that we're not already cancelled + select { + case <-notify: + return + case <-shutdownCtx.Done(): + return + default: + } + + // Handle events + if event.Name == "logout" { + if ctx.Session.ID() == event.Data { + event = &eventsource.Event{ + Name: "logout", + Data: "here", + } + _ = writeEvent(conn, event) + go unregister() + auth.HandleSignOut(ctx) + break + } + // Replace the event - we don't want to expose the session ID to the user + event = &eventsource.Event{ + Name: "logout", + Data: "elsewhere", + } + } + if err := writeEvent(conn, event); err != nil { + return + } + } + } +} + +func readMessagesFromConnToChan(conn *websocket.Conn, messageChan chan readMessage) { + defer func() { + close(messageChan) // Please note: this has to be within a wrapping anonymous func otherwise it will be evaluated when creating the defer + _ = conn.Close() + }() + conn.SetReadLimit(maximumMessageSize) + if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil { + log.Error("unable to SetReadDeadline: %v", err) + return + } + conn.SetPongHandler(func(string) error { + if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil { + log.Error("unable to SetReadDeadline: %v", err) + } + return nil + }) + + for { + messageType, message, err := conn.ReadMessage() + messageChan <- readMessage{ + messageType: messageType, + message: message, + err: err, + } + if err != nil { + // don't need to handle the error here as it is passed down the channel + return + } + if err := conn.SetReadDeadline(time.Now().Add(pongWait)); err != nil { + log.Error("unable to SetReadDeadline: %v", err) + return + } + } +} + +func writeEvent(conn *websocket.Conn, event *eventsource.Event) error { + if err := conn.SetWriteDeadline(time.Now().Add(writeWait)); err != nil { + log.Error("unable to SetWriteDeadline: %v", err) + return err + } + + w, err := conn.NextWriter(websocket.TextMessage) + if err != nil { + log.Error("Unable to get writer for websocket %v", err) + return err + } + + if err := json.NewEncoder(w).Encode(event); err != nil { + log.Error("Unable to create encoder for %v %v", event, err) + return err + } + if err := w.Close(); err != nil { + log.Warn("Unable to close writer for websocket %v", err) + return err + } + return nil +} diff --git a/routers/web/web.go b/routers/web/web.go index b4e8666c44fd2..897748da78be9 100644 --- a/routers/web/web.go +++ b/routers/web/web.go @@ -365,6 +365,7 @@ func RegisterRoutes(m *web.Route) { }, reqSignOut) m.Any("/user/events", routing.MarkLongPolling, events.Events) + m.Any("/user/websocket", routing.MarkLongPolling, events.Websocket) m.Group("/login/oauth", func() { m.Get("/authorize", bindIgnErr(forms.AuthorizationForm{}), auth.AuthorizeOAuth) diff --git a/web_src/js/features/eventsource.sharedworker.js b/web_src/js/features/eventsource.sharedworker.js index 824ccfea79f84..95b2e1bbd3431 100644 --- a/web_src/js/features/eventsource.sharedworker.js +++ b/web_src/js/features/eventsource.sharedworker.js @@ -46,9 +46,13 @@ class Source { if (this.listening[eventType]) return; this.listening[eventType] = true; this.eventSource.addEventListener(eventType, (event) => { + let data; + if (event.data) { + data = JSON.parse(event.data); + } this.notifyClients({ type: eventType, - data: event.data + data }); }); } diff --git a/web_src/js/features/notification.js b/web_src/js/features/notification.js index 36df196cac2d8..24d83057c51a0 100644 --- a/web_src/js/features/notification.js +++ b/web_src/js/features/notification.js @@ -24,21 +24,17 @@ export function initNotificationsTable() { }); } -async function receiveUpdateCount(event) { +async function receiveUpdateCount(data, document) { try { - const data = JSON.parse(event.data); - - const notificationCount = document.querySelector('.notification_count'); - if (data.Count > 0) { - notificationCount.classList.remove('hidden'); - } else { - notificationCount.classList.add('hidden'); + const notificationCounts = document.querySelectorAll('.notification_count'); + for (const count of notificationCounts) { + count.classList.toggle('hidden', data.Count === 0); + count.textContent = `${data.Count}`; } - notificationCount.textContent = `${data.Count}`; await updateNotificationTable(); } catch (error) { - console.error(error, event); + console.error(error, data); } } @@ -49,26 +45,39 @@ export function initNotificationCount() { return; } - if (notificationSettings.EventSourceUpdateTime > 0 && !!window.EventSource && window.SharedWorker) { + let worker; + let workerUrl; + + if (notificationSettings.EventSourceUpdateTime > 0 && !!window.WebSocket && window.SharedWorker) { + // Try to connect to the event source via the shared worker first + worker = new SharedWorker(`${__webpack_public_path__}js/websocket.sharedworker.js`, 'notification-worker'); + workerUrl = `${window.location.origin}${appSubUrl}/user/websocket`; + } else if (notificationSettings.EventSourceUpdateTime > 0 && !!window.EventSource && window.SharedWorker) { // Try to connect to the event source via the shared worker first - const worker = new SharedWorker(`${__webpack_public_path__}js/eventsource.sharedworker.js`, 'notification-worker'); + worker = new SharedWorker(`${__webpack_public_path__}js/eventsource.sharedworker.js`, 'notification-worker'); + workerUrl = `${window.location.origin}${appSubUrl}/user/events`; + } + + const currentDocument = document; + + if (worker) { worker.addEventListener('error', (event) => { - console.error(event); + console.error('error from listener: ', event); }); worker.port.addEventListener('messageerror', () => { console.error('Unable to deserialize message'); }); worker.port.postMessage({ type: 'start', - url: `${window.location.origin}${appSubUrl}/user/events`, + url: workerUrl, }); worker.port.addEventListener('message', (event) => { if (!event.data || !event.data.type) { - console.error(event); + console.error('Unexpected event:', event); return; } if (event.data.type === 'notification-count') { - const _promise = receiveUpdateCount(event.data); + const _promise = receiveUpdateCount(event.data.data, currentDocument); } else if (event.data.type === 'error') { console.error(event.data); } else if (event.data.type === 'logout') { diff --git a/web_src/js/features/stopwatch.js b/web_src/js/features/stopwatch.js index d63da4155af27..a506dfde6c65a 100644 --- a/web_src/js/features/stopwatch.js +++ b/web_src/js/features/stopwatch.js @@ -26,9 +26,19 @@ export function initStopwatch() { $(this).parent().trigger('submit'); }); - if (notificationSettings.EventSourceUpdateTime > 0 && !!window.EventSource && window.SharedWorker) { + let worker; + let workerUrl; + if (notificationSettings.EventSourceUpdateTime > 0 && !!window.WebSocket && window.SharedWorker) { // Try to connect to the event source via the shared worker first - const worker = new SharedWorker(`${__webpack_public_path__}js/eventsource.sharedworker.js`, 'notification-worker'); + worker = new SharedWorker(`${__webpack_public_path__}js/websocket.sharedworker.js`, 'notification-worker'); + workerUrl = `${window.location.origin}${appSubUrl}/user/websocket`; + } else if (notificationSettings.EventSourceUpdateTime > 0 && !!window.EventSource && window.SharedWorker) { + // Try to connect to the event source via the shared worker first + worker = new SharedWorker(`${__webpack_public_path__}js/eventsource.sharedworker.js`, 'notification-worker'); + workerUrl = `${window.location.origin}${appSubUrl}/user/events`; + } + + if (worker) { worker.addEventListener('error', (event) => { console.error(event); }); @@ -37,7 +47,7 @@ export function initStopwatch() { }); worker.port.postMessage({ type: 'start', - url: `${window.location.origin}${appSubUrl}/user/events`, + url: workerUrl, }); worker.port.addEventListener('message', (event) => { if (!event.data || !event.data.type) { @@ -45,7 +55,7 @@ export function initStopwatch() { return; } if (event.data.type === 'stopwatches') { - updateStopwatchData(JSON.parse(event.data.data)); + updateStopwatchData(event.data.data); } else if (event.data.type === 'error') { console.error(event.data); } else if (event.data.type === 'logout') { diff --git a/web_src/js/features/websocket.sharedworker.js b/web_src/js/features/websocket.sharedworker.js new file mode 100644 index 0000000000000..096f485895a34 --- /dev/null +++ b/web_src/js/features/websocket.sharedworker.js @@ -0,0 +1,172 @@ +const sourcesByUrl = {}; +const sourcesByPort = {}; + +class Source { + constructor(url) { + this.url = url.replace(/^http/, 'ws'); + this.webSocket = new WebSocket(this.url); + this.listening = {}; + this.clients = []; + this.listen('open'); + this.listen('close'); + this.listen('logout'); + this.listen('notification-count'); + this.listen('stopwatches'); + this.listen('error'); + this.webSocket.addEventListener('error', (error) => { + this.lastError = error; + }); + this.webSocket.addEventListener('message', (event) => { + const message = JSON.parse(event.data); + if (!message) { + return; + } + if (this.listening[message.Name]) { + this.notifyClients({ + type: message.Name, + data: message.Data + }); + } + }); + this.webSocket.addEventListener('close', (event) => { + if (!this.webSocket) { + return; + } + const oldWebSocket = this.webSocket; + this.webSocket = null; + this.notifyClients({ + type: 'close', + data: event + }); + oldWebSocket.close(); + }); + } + + register(port) { + if (this.clients.includes(port)) return; + + this.clients.push(port); + + port.postMessage({ + type: 'status', + message: `registered to ${this.url}`, + }); + + if (!this.webSocket) { + if (this.lastError) { + port.postMessage({ + type: 'error', + message: `websocket disconnected: ${this.lastError}` + }); + } else { + port.postMessage({ + type: 'error', + message: 'websocket disconnected' + }); + } + } + } + + deregister(port) { + const portIdx = this.clients.indexOf(port); + if (portIdx < 0) { + return this.clients.length; + } + this.clients.splice(portIdx, 1); + return this.clients.length; + } + + close() { + if (!this.webSocket) return; + const oldWebSocket = this.webSocket; + this.webSocket = null; + oldWebSocket.close(); + } + + listen(eventType) { + if (this.listening[eventType]) return; + this.listening[eventType] = true; + } + + notifyClients(event) { + for (const client of this.clients) { + client.postMessage(event); + } + } + + status(port) { + port.postMessage({ + type: 'status', + message: `url: ${this.url} readyState: ${this.webSocket.readyState}`, + }); + } +} + +self.addEventListener('connect', (e) => { + for (const port of e.ports) { + port.addEventListener('message', (event) => { + if (event.data.type === 'start') { + const url = event.data.url; + if (sourcesByUrl[url]) { + // we have a Source registered to this url + const source = sourcesByUrl[url]; + if (source.webSocket) { + source.register(port); + sourcesByPort[port] = source; + return; + } + sourcesByUrl[url] = null; + } + let source = sourcesByPort[port]; + if (source) { + if (source.webSocket && source.url === url) return; + + // How this has happened I don't understand... + // deregister from that source + const count = source.deregister(port); + // Clean-up + if (count === 0) { + source.close(); + sourcesByUrl[source.url] = null; + } + } + // Create a new Source + source = new Source(url); + source.register(port); + sourcesByUrl[url] = source; + sourcesByPort[port] = source; + } else if (event.data.type === 'listen') { + const source = sourcesByPort[port]; + source.listen(event.data.eventType); + } else if (event.data.type === 'close') { + const source = sourcesByPort[port]; + + if (!source) return; + + const count = source.deregister(port); + if (count === 0) { + source.close(); + sourcesByUrl[source.url] = null; + sourcesByPort[port] = null; + } + } else if (event.data.type === 'status') { + const source = sourcesByPort[port]; + if (!source) { + port.postMessage({ + type: 'status', + message: 'not connected', + }); + return; + } + source.status(port); + } else { + // just send it back + port.postMessage({ + type: 'error', + message: `received but don't know how to handle: ${event.data}`, + }); + } + }); + port.start(); + } +}); diff --git a/webpack.config.js b/webpack.config.js index 5109103f7faf7..c865c472f017d 100644 --- a/webpack.config.js +++ b/webpack.config.js @@ -62,6 +62,9 @@ export default { 'eventsource.sharedworker': [ fileURLToPath(new URL('web_src/js/features/eventsource.sharedworker.js', import.meta.url)), ], + 'websocket.sharedworker': [ + fileURLToPath(new URL('web_src/js/features/websocket.sharedworker.js', import.meta.url)), + ], ...themes, }, devtool: false,