Skip to main content

WebSocket


Server

Client

Proxy

import (
"errors"
"fmt"
"io"
"net/http"
"net/url"

"github.com/gin-gonic/gin"
"github.com/gobwas/ws"
)

// WebsocketProxy upgrades the incoming request to a WebSocket connection,
// dials the upstream WebSocket service and relays raw frames in both
// directions until one side fails or a close frame has been forwarded.
//
// The raw query string of the incoming request is passed through to the
// upstream URL. A relay that ends with a forwarded close frame is treated as
// a normal shutdown; any other relay failure is attached to the gin context
// with a "client: " or "server: " prefix naming the channel it arrived on.
func WebsocketProxy(c *gin.Context) {
	const upstreamURL = "ws://websocket.default.svc.cluster.local:8080/ws"

	// Parse the target before upgrading so that a malformed URL is reported
	// instead of causing a nil dereference below.
	u, err := url.Parse(upstreamURL)
	if err != nil {
		c.Status(http.StatusInternalServerError)
		c.Error(fmt.Errorf("parse upstream url: %w", err))
		return
	}
	u.RawQuery = c.Request.URL.RawQuery

	connClient, _, _, err := ws.UpgradeHTTP(c.Request, c.Writer)
	if err != nil {
		c.Status(http.StatusInternalServerError)
		return
	}
	defer connClient.Close()

	dial := ws.Dialer{}
	dialHeader := http.Header{}
	dial.Header = ws.HandshakeHeaderHTTP(dialHeader)

	connServer, br, _, err := dial.Dial(c.Request.Context(), u.String())
	if err != nil {
		c.Status(http.StatusInternalServerError)
		// Record the cause; previously a failed upstream dial left no trace.
		c.Error(fmt.Errorf("dial upstream: %w", err))
		return
	}
	defer connServer.Close()

	// Dial hands back a non-nil buffered reader when the upstream sent data
	// together with its handshake response; read through it so those bytes
	// are not dropped.
	var serverReader io.Reader = connServer
	if br != nil {
		serverReader = br
	}

	// Each relay goroutine sends exactly one value before exiting and both may
	// pick the same channel, so a capacity of two guarantees that neither can
	// block once this handler has returned.
	errClient := make(chan error, 2)
	errServer := make(chan error, 2)

	go relayFrames(connClient, connServer, errClient, errServer)
	go relayFrames(serverReader, connClient, errServer, errClient)

	var relayErr error
	select {
	case e := <-errClient:
		relayErr = fmt.Errorf("client: %w", e)
	case e := <-errServer:
		relayErr = fmt.Errorf("server: %w", e)
	}

	// The deferred Close calls unblock whichever relay is still reading.
	if errors.Is(relayErr, errConnectionClose) {
		return
	}
	c.Status(http.StatusInternalServerError)
	c.Error(relayErr)
}

// relayFrames copies WebSocket frames from src to dst, header first and then
// the payload in bounded chunks, until an error occurs or a close frame has
// been forwarded. It sends exactly one value and then returns: read failures
// go to errSrc, write failures and the errConnectionClose notification go to
// errDst.
func relayFrames(src io.Reader, dst io.Writer, errSrc, errDst chan<- error) {
	// Payloads are streamed through a fixed buffer so that the length field
	// announced by the peer never dictates the size of an allocation.
	const relayBufferSize = 32 * 1024
	buf := make([]byte, relayBufferSize)

	for {
		header, err := ws.ReadHeader(src)
		if err != nil {
			errSrc <- err
			return
		}

		if err := ws.WriteHeader(dst, header); err != nil {
			errDst <- err
			return
		}

		for remaining := header.Length; remaining > 0; {
			chunk := buf
			if remaining < int64(len(chunk)) {
				chunk = chunk[:remaining]
			}

			if _, err := io.ReadFull(src, chunk); err != nil {
				errSrc <- err
				return
			}

			if _, err := dst.Write(chunk); err != nil {
				errDst <- err
				return
			}

			remaining -= int64(len(chunk))
		}

		if header.OpCode == ws.OpClose {
			errDst <- errConnectionClose
			return
		}
	}
}