Compare commits
1 Commits
| Author | SHA1 | Date | |
|---|---|---|---|
|
|
66437453f5 |
242
yesod-websockets/WaiWS.hs
Normal file
242
yesod-websockets/WaiWS.hs
Normal file
@ -0,0 +1,242 @@
|
|||||||
|
{-# LANGUAGE DeriveDataTypeable #-}
|
||||||
|
{-# LANGUAGE OverloadedStrings #-}
|
||||||
|
{-# LANGUAGE RecordWildCards #-}
|
||||||
|
module WaiWS where
|
||||||
|
|
||||||
|
import Network.Wai
|
||||||
|
import Control.Exception (Exception, throwIO, assert)
|
||||||
|
import Control.Applicative ((<$>))
|
||||||
|
import Control.Monad (when, forever, unless)
|
||||||
|
import Data.Typeable (Typeable)
|
||||||
|
import Network.HTTP.Types (status200, status404)
|
||||||
|
import Blaze.ByteString.Builder
|
||||||
|
import Data.Monoid ((<>), mempty)
|
||||||
|
import qualified Crypto.Hash.SHA1 as SHA1
|
||||||
|
import Debug.Trace
|
||||||
|
import Data.Word (Word8, Word32, Word64)
|
||||||
|
import Data.ByteString (ByteString)
|
||||||
|
import Data.Bits ((.|.), testBit, clearBit, shiftL, (.&.), Bits, xor, shiftR)
|
||||||
|
import qualified Data.Map as Map
|
||||||
|
import Data.Maybe (isJust)
|
||||||
|
import qualified Data.ByteString as S
|
||||||
|
import qualified Data.ByteString.Char8 as S8
|
||||||
|
import qualified Data.ByteString.Base64 as B64
|
||||||
|
import Data.IORef
|
||||||
|
import Data.Char (toUpper)
|
||||||
|
import qualified Data.Conduit as C
|
||||||
|
|
||||||
|
|
||||||
|
data Connection = Connection
|
||||||
|
{ connSend :: Bool -> ByteString -> IO ()
|
||||||
|
, connRecv :: IO ByteString
|
||||||
|
}
|
||||||
|
|
||||||
|
websocketsApp :: Request -> Maybe (C.Source IO ByteString -> C.Sink ByteString IO () -> (WaiWS.Connection -> IO a) -> IO a)
|
||||||
|
websocketsApp req
|
||||||
|
-- FIXME handle keep-alive, Upgrade | lookup "connection" reqhs /= Just "Upgrade" = backup sendResponse
|
||||||
|
| lookup "upgrade" reqhs /= Just "websocket" = Nothing
|
||||||
|
| lookup "sec-websocket-version" reqhs /= Just "13" = Nothing
|
||||||
|
| Just key <- lookup "sec-websocket-key" reqhs = Just $ \src0 sink app -> do
|
||||||
|
(rsrc0, ()) <- src0 C.$$+ return ()
|
||||||
|
rsrcRef <- newIORef rsrc0
|
||||||
|
let recv = do
|
||||||
|
rsrc <- readIORef rsrcRef
|
||||||
|
(rsrc', mbs) <- rsrc C.$$++ C.await
|
||||||
|
writeIORef rsrcRef rsrc'
|
||||||
|
case mbs of
|
||||||
|
Nothing -> return ""
|
||||||
|
Just "" -> recv
|
||||||
|
Just bs -> return bs
|
||||||
|
|
||||||
|
let send x = C.yield x C.$$ sink
|
||||||
|
let handshake = fromByteString "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "
|
||||||
|
<> fromByteString (B64.encode key')
|
||||||
|
<> fromByteString "\r\n\r\n"
|
||||||
|
key' = SHA1.hash $ key <> "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
|
||||||
|
toByteStringIO send handshake
|
||||||
|
|
||||||
|
let msg = "This is a test"
|
||||||
|
toByteStringIO send $ wsDataToBuilder $ Frame True OpText Nothing $ fromIntegral $ S.length msg
|
||||||
|
toByteStringIO send $ wsDataToBuilder $ Payload $ fromByteString msg
|
||||||
|
|
||||||
|
src <- mkSource recv
|
||||||
|
|
||||||
|
let recv front0 = waitForFrame src $ \isFinished opcode _ _ getBS -> do
|
||||||
|
let loop front = do
|
||||||
|
bs <- getBS
|
||||||
|
if S.null bs
|
||||||
|
then return front
|
||||||
|
else loop $ front . (bs:)
|
||||||
|
front <- loop front0
|
||||||
|
if isFinished
|
||||||
|
then return $ S.concat $ front []
|
||||||
|
else recv front
|
||||||
|
app Connection
|
||||||
|
{ connSend = \isText payload -> do
|
||||||
|
toByteStringIO send $ wsDataToBuilder $ Frame True (if isText then OpText else OpBinary) Nothing $ fromIntegral $ S.length payload
|
||||||
|
send payload
|
||||||
|
, connRecv = recv id
|
||||||
|
}
|
||||||
|
| otherwise = Nothing
|
||||||
|
where
|
||||||
|
reqhs = requestHeaders req
|
||||||
|
|
||||||
|
type FrameFinished = Bool
|
||||||
|
type MaskingKey = Word32
|
||||||
|
type PayloadSize = Word64
|
||||||
|
|
||||||
|
data WSData payload
|
||||||
|
= Frame FrameFinished Opcode (Maybe MaskingKey) PayloadSize
|
||||||
|
| Payload payload
|
||||||
|
deriving Show
|
||||||
|
|
||||||
|
data Opcode = OpCont | OpText | OpBinary | OpClose | OpPing | OpPong
|
||||||
|
deriving (Show, Eq, Ord, Enum, Bounded)
|
||||||
|
|
||||||
|
opcodeToWord8 :: Opcode -> Word8
|
||||||
|
opcodeToWord8 OpCont = 0x0
|
||||||
|
opcodeToWord8 OpText = 0x1
|
||||||
|
opcodeToWord8 OpBinary = 0x2
|
||||||
|
opcodeToWord8 OpClose = 0x8
|
||||||
|
opcodeToWord8 OpPing = 0x9
|
||||||
|
opcodeToWord8 OpPong = 0xA
|
||||||
|
|
||||||
|
opcodeFromWord8 :: Word8 -> Maybe Opcode
|
||||||
|
opcodeFromWord8 =
|
||||||
|
flip Map.lookup m
|
||||||
|
where
|
||||||
|
m = Map.fromList $ map (\o -> (opcodeToWord8 o, o)) [minBound..maxBound]
|
||||||
|
|
||||||
|
wsDataToBuilder :: WSData Builder -> Builder
|
||||||
|
wsDataToBuilder (Payload builder) = builder
|
||||||
|
wsDataToBuilder (Frame finished opcode mmask payload) =
|
||||||
|
fromWord8 byte1
|
||||||
|
<> fromWord8 byte2
|
||||||
|
<> lenrest
|
||||||
|
<> maybe mempty fromWord32be mmask
|
||||||
|
where
|
||||||
|
byte1 = (if finished then 128 else 0) .|. opcodeToWord8 opcode
|
||||||
|
byte2 = (if isJust mmask then 128 else 0) .|. len1
|
||||||
|
|
||||||
|
(len1, lenrest)
|
||||||
|
| payload <= 125 = (fromIntegral payload, mempty)
|
||||||
|
| payload <= 65536 = (126, fromWord16be $ fromIntegral payload)
|
||||||
|
| otherwise = (127, fromWord64be $ fromIntegral payload)
|
||||||
|
|
||||||
|
data WSException = ConnectionClosed
|
||||||
|
| RSVBitsSet Word8
|
||||||
|
| InvalidOpcode Word8
|
||||||
|
deriving (Show, Typeable)
|
||||||
|
instance Exception WSException
|
||||||
|
|
||||||
|
data Source = Source (IO ByteString) (IORef ByteString)
|
||||||
|
|
||||||
|
mkSource :: IO ByteString -> IO Source
|
||||||
|
mkSource recv = Source recv <$> newIORef S.empty
|
||||||
|
|
||||||
|
-- | Guaranteed to never return an empty ByteString.
|
||||||
|
getBS :: Source -> IO ByteString
|
||||||
|
getBS (Source next ref) = do
|
||||||
|
bs <- readIORef ref
|
||||||
|
if S.null bs
|
||||||
|
then do
|
||||||
|
bs <- next
|
||||||
|
when (S.null bs) (throwIO ConnectionClosed)
|
||||||
|
return bs
|
||||||
|
else writeIORef ref S.empty >> return bs
|
||||||
|
|
||||||
|
leftover :: Source -> ByteString -> IO ()
|
||||||
|
leftover (Source _ ref) bs = writeIORef ref bs
|
||||||
|
|
||||||
|
getWord8 :: Source -> IO Word8
|
||||||
|
getWord8 src = do
|
||||||
|
bs <- getBS src
|
||||||
|
leftover src $ S.tail bs
|
||||||
|
return $ S.head bs
|
||||||
|
|
||||||
|
getBytes :: (Num word, Bits word) => Source -> Int -> IO word
|
||||||
|
getBytes src =
|
||||||
|
loop 0
|
||||||
|
where
|
||||||
|
loop total 0 = return total
|
||||||
|
loop total remaining = do
|
||||||
|
x <- getWord8 src -- FIXME not very efficient, better to use ByteString directly
|
||||||
|
loop (shiftL total 8 .|. fromIntegral x) (remaining - 1)
|
||||||
|
|
||||||
|
waitForFrame :: Source -> (FrameFinished -> Opcode -> Maybe MaskingKey -> PayloadSize -> IO ByteString -> IO a) -> IO a
|
||||||
|
waitForFrame src yield = do
|
||||||
|
byte1 <- getWord8 src
|
||||||
|
byte2 <- getWord8 src
|
||||||
|
|
||||||
|
when (testBit byte1 6 || testBit byte1 5 || testBit byte1 4)
|
||||||
|
$ throwIO $ RSVBitsSet byte1
|
||||||
|
|
||||||
|
let opcode' = byte1 .&. 0x0F
|
||||||
|
opcode <-
|
||||||
|
case opcodeFromWord8 opcode' of
|
||||||
|
Nothing -> throwIO $ InvalidOpcode opcode'
|
||||||
|
Just o -> return o
|
||||||
|
|
||||||
|
let isFinished = testBit byte1 7
|
||||||
|
isMasked = testBit byte2 7
|
||||||
|
len' = byte2 `clearBit` 7
|
||||||
|
|
||||||
|
payloadSize <-
|
||||||
|
case () of
|
||||||
|
()
|
||||||
|
| len' <= 125 -> return $ fromIntegral len'
|
||||||
|
| len' == 126 -> getBytes src 2
|
||||||
|
| assert (len' == 127) otherwise -> getBytes src 8
|
||||||
|
|
||||||
|
mmask <- if isMasked then Just <$> getBytes src 4 else return Nothing
|
||||||
|
let unmask' =
|
||||||
|
case mmask of
|
||||||
|
Nothing -> \_ bs -> bs
|
||||||
|
Just mask -> unmask mask
|
||||||
|
|
||||||
|
consumedRef <- newIORef 0
|
||||||
|
let getPayload = handlePayload unmask' payloadSize consumedRef
|
||||||
|
|
||||||
|
res <- yield isFinished opcode mmask payloadSize getPayload
|
||||||
|
let drain = do
|
||||||
|
bs <- getPayload
|
||||||
|
unless (S.null bs) drain
|
||||||
|
drain
|
||||||
|
return res
|
||||||
|
|
||||||
|
where
|
||||||
|
handlePayload unmask' totalSize consumedRef = do
|
||||||
|
consumed <- readIORef consumedRef
|
||||||
|
if consumed >= totalSize
|
||||||
|
then return S.empty
|
||||||
|
else do
|
||||||
|
bs <- getBS src
|
||||||
|
let len = fromIntegral $ S.length bs
|
||||||
|
consumed' = consumed + len
|
||||||
|
if consumed' <= totalSize
|
||||||
|
then do
|
||||||
|
writeIORef consumedRef consumed'
|
||||||
|
return $ unmask' consumed bs
|
||||||
|
else do
|
||||||
|
let (x, y) = S.splitAt (fromIntegral $ totalSize - consumed) bs
|
||||||
|
leftover src y
|
||||||
|
return $ unmask' consumed x
|
||||||
|
|
||||||
|
unmask :: MaskingKey -> Word64 -> ByteString -> ByteString
|
||||||
|
unmask key offset' masked =
|
||||||
|
-- we really want a mapWithIndex...
|
||||||
|
fst $ S.unfoldrN len f 0
|
||||||
|
where
|
||||||
|
len = S.length masked
|
||||||
|
|
||||||
|
f idx | idx >= len = Nothing
|
||||||
|
f idx = Just (getIndex idx, idx + 1)
|
||||||
|
|
||||||
|
offset = fromIntegral $ offset' `mod` 4
|
||||||
|
|
||||||
|
getIndex idx = S.index masked idx `xor` maskByte ((offset + idx) `mod` 4)
|
||||||
|
|
||||||
|
maskByte 0 = fromIntegral $ key `shiftR` 24
|
||||||
|
maskByte 1 = fromIntegral $ key `shiftR` 16
|
||||||
|
maskByte 2 = fromIntegral $ key `shiftR` 8
|
||||||
|
maskByte 3 = fromIntegral key
|
||||||
@ -9,6 +9,7 @@ module Yesod.WebSockets
|
|||||||
, sendBinaryData
|
, sendBinaryData
|
||||||
-- * Conduit API
|
-- * Conduit API
|
||||||
, sourceWS
|
, sourceWS
|
||||||
|
, sourceWSText
|
||||||
, sinkWSText
|
, sinkWSText
|
||||||
, sinkWSBinary
|
, sinkWSBinary
|
||||||
-- * Async helpers
|
-- * Async helpers
|
||||||
@ -26,14 +27,16 @@ import Control.Monad.Trans.Control (MonadBaseControl (liftBaseWith,
|
|||||||
import Control.Monad.Trans.Reader (ReaderT (ReaderT, runReaderT))
|
import Control.Monad.Trans.Reader (ReaderT (ReaderT, runReaderT))
|
||||||
import qualified Data.Conduit as C
|
import qualified Data.Conduit as C
|
||||||
import qualified Data.Conduit.List as CL
|
import qualified Data.Conduit.List as CL
|
||||||
import qualified Network.Wai.Handler.WebSockets as WaiWS
|
|
||||||
import qualified Network.WebSockets as WS
|
|
||||||
import qualified Yesod.Core as Y
|
import qualified Yesod.Core as Y
|
||||||
|
import qualified WaiWS
|
||||||
|
import Data.ByteString (ByteString)
|
||||||
|
import Data.Text (Text)
|
||||||
|
import Data.Text.Encoding (encodeUtf8, decodeUtf8)
|
||||||
|
|
||||||
-- | A transformer for a WebSockets handler.
|
-- | A transformer for a WebSockets handler.
|
||||||
--
|
--
|
||||||
-- Since 0.1.0
|
-- Since 0.1.0
|
||||||
type WebSocketsT = ReaderT WS.Connection
|
type WebSocketsT = ReaderT WaiWS.Connection
|
||||||
|
|
||||||
-- | Attempt to run a WebSockets handler. This function first checks if the
|
-- | Attempt to run a WebSockets handler. This function first checks if the
|
||||||
-- client initiated a WebSockets connection and, if so, runs the provided
|
-- client initiated a WebSockets connection and, if so, runs the provided
|
||||||
@ -45,50 +48,50 @@ type WebSocketsT = ReaderT WS.Connection
|
|||||||
webSockets :: (Y.MonadBaseControl IO m, Y.MonadHandler m) => WebSocketsT m () -> m ()
|
webSockets :: (Y.MonadBaseControl IO m, Y.MonadHandler m) => WebSocketsT m () -> m ()
|
||||||
webSockets inner = do
|
webSockets inner = do
|
||||||
req <- Y.waiRequest
|
req <- Y.waiRequest
|
||||||
when (WaiWS.isWebSocketsReq req) $
|
case WaiWS.websocketsApp req of
|
||||||
Y.sendRawResponse $ \src sink -> control $ \runInIO -> WaiWS.runWebSockets
|
Nothing -> return ()
|
||||||
WS.defaultConnectionOptions
|
Just runWebSockets -> Y.sendRawResponse $ \src sink -> control $ \runInIO -> runWebSockets src sink $ runInIO . runReaderT inner
|
||||||
(WaiWS.getRequestHead req)
|
|
||||||
(\pconn -> do
|
|
||||||
conn <- WS.acceptRequest pconn
|
|
||||||
runInIO $ runReaderT inner conn)
|
|
||||||
src
|
|
||||||
sink
|
|
||||||
|
|
||||||
-- | Receive a piece of data from the client.
|
-- | Receive a piece of data from the client.
|
||||||
--
|
--
|
||||||
-- Since 0.1.0
|
-- Since 0.1.0
|
||||||
receiveData :: (MonadIO m, WS.WebSocketsData a) => WebSocketsT m a
|
receiveData :: (MonadIO m) => WebSocketsT m ByteString
|
||||||
receiveData = ReaderT $ liftIO . WS.receiveData
|
receiveData = ReaderT $ liftIO . WaiWS.connRecv
|
||||||
|
|
||||||
-- | Send a textual messsage to the client.
|
-- | Send a textual messsage to the client.
|
||||||
--
|
--
|
||||||
-- Since 0.1.0
|
-- Since 0.1.0
|
||||||
sendTextData :: (MonadIO m, WS.WebSocketsData a) => a -> WebSocketsT m ()
|
sendTextData :: MonadIO m => Text -> WebSocketsT m ()
|
||||||
sendTextData x = ReaderT $ liftIO . flip WS.sendTextData x
|
sendTextData x = ReaderT $ \conn -> liftIO $ WaiWS.connSend conn True $ encodeUtf8 x
|
||||||
|
|
||||||
-- | Send a binary messsage to the client.
|
-- | Send a binary messsage to the client.
|
||||||
--
|
--
|
||||||
-- Since 0.1.0
|
-- Since 0.1.0
|
||||||
sendBinaryData :: (MonadIO m, WS.WebSocketsData a) => a -> WebSocketsT m ()
|
sendBinaryData :: MonadIO m => ByteString -> WebSocketsT m ()
|
||||||
sendBinaryData x = ReaderT $ liftIO . flip WS.sendBinaryData x
|
sendBinaryData x = ReaderT $ \conn -> liftIO $ WaiWS.connSend conn False x
|
||||||
|
|
||||||
-- | A @Source@ of WebSockets data from the user.
|
-- | A @Source@ of WebSockets data from the user.
|
||||||
--
|
--
|
||||||
-- Since 0.1.0
|
-- Since 0.1.0
|
||||||
sourceWS :: (MonadIO m, WS.WebSocketsData a) => C.Producer (WebSocketsT m) a
|
sourceWS :: MonadIO m => C.Producer (WebSocketsT m) ByteString
|
||||||
sourceWS = forever $ Y.lift receiveData >>= C.yield
|
sourceWS = forever $ Y.lift receiveData >>= C.yield
|
||||||
|
|
||||||
|
-- | A @Source@ of WebSockets data from the user.
|
||||||
|
--
|
||||||
|
-- Since 0.1.0
|
||||||
|
sourceWSText :: MonadIO m => C.Producer (WebSocketsT m) Text
|
||||||
|
sourceWSText = forever $ Y.lift receiveData >>= C.yield . decodeUtf8
|
||||||
|
|
||||||
-- | A @Sink@ for sending textual data to the user.
|
-- | A @Sink@ for sending textual data to the user.
|
||||||
--
|
--
|
||||||
-- Since 0.1.0
|
-- Since 0.1.0
|
||||||
sinkWSText :: (MonadIO m, WS.WebSocketsData a) => C.Consumer a (WebSocketsT m) ()
|
sinkWSText :: MonadIO m => C.Consumer Text (WebSocketsT m) ()
|
||||||
sinkWSText = CL.mapM_ sendTextData
|
sinkWSText = CL.mapM_ sendTextData
|
||||||
|
|
||||||
-- | A @Sink@ for sending binary data to the user.
|
-- | A @Sink@ for sending binary data to the user.
|
||||||
--
|
--
|
||||||
-- Since 0.1.0
|
-- Since 0.1.0
|
||||||
sinkWSBinary :: (MonadIO m, WS.WebSocketsData a) => C.Consumer a (WebSocketsT m) ()
|
sinkWSBinary :: MonadIO m => C.Consumer ByteString (WebSocketsT m) ()
|
||||||
sinkWSBinary = CL.mapM_ sendBinaryData
|
sinkWSBinary = CL.mapM_ sendBinaryData
|
||||||
|
|
||||||
-- | Generalized version of 'A.race'.
|
-- | Generalized version of 'A.race'.
|
||||||
|
|||||||
@ -1,6 +1,7 @@
|
|||||||
{-# LANGUAGE QuasiQuotes, TemplateHaskell, TypeFamilies #-}
|
{-# LANGUAGE QuasiQuotes, TemplateHaskell, TypeFamilies #-}
|
||||||
import Yesod.Core
|
import Yesod.Core
|
||||||
import Yesod.WebSockets
|
import Yesod.WebSockets
|
||||||
|
import qualified Data.Text as T
|
||||||
import qualified Data.Text.Lazy as TL
|
import qualified Data.Text.Lazy as TL
|
||||||
import Control.Monad (forever)
|
import Control.Monad (forever)
|
||||||
import Control.Monad.Trans.Reader
|
import Control.Monad.Trans.Reader
|
||||||
@ -16,16 +17,16 @@ mkYesod "App" [parseRoutes|
|
|||||||
/ HomeR GET
|
/ HomeR GET
|
||||||
|]
|
|]
|
||||||
|
|
||||||
timeSource :: MonadIO m => Source m TL.Text
|
timeSource :: MonadIO m => Source m T.Text
|
||||||
timeSource = forever $ do
|
timeSource = forever $ do
|
||||||
now <- liftIO getCurrentTime
|
now <- liftIO getCurrentTime
|
||||||
yield $ TL.pack $ show now
|
yield $ T.pack $ show now
|
||||||
liftIO $ threadDelay 5000000
|
liftIO $ threadDelay 5000000
|
||||||
|
|
||||||
getHomeR :: Handler Html
|
getHomeR :: Handler Html
|
||||||
getHomeR = do
|
getHomeR = do
|
||||||
webSockets $ race_
|
webSockets $ race_
|
||||||
(sourceWS $$ mapC TL.toUpper =$ sinkWSText)
|
(sourceWSText $$ mapC T.toUpper =$ sinkWSText)
|
||||||
(timeSource $$ sinkWSText)
|
(timeSource $$ sinkWSText)
|
||||||
defaultLayout $
|
defaultLayout $
|
||||||
toWidget
|
toWidget
|
||||||
|
|||||||
@ -16,14 +16,21 @@ cabal-version: >=1.8
|
|||||||
|
|
||||||
library
|
library
|
||||||
exposed-modules: Yesod.WebSockets
|
exposed-modules: Yesod.WebSockets
|
||||||
|
other-modules: WaiWS
|
||||||
build-depends: base >= 4.5 && < 5
|
build-depends: base >= 4.5 && < 5
|
||||||
, wai-websockets >= 2.1
|
|
||||||
, websockets >= 0.8
|
|
||||||
, transformers >= 0.2
|
, transformers >= 0.2
|
||||||
, yesod-core >= 1.2.7
|
, yesod-core >= 1.2.7
|
||||||
, monad-control >= 0.3
|
, monad-control >= 0.3
|
||||||
, conduit >= 1.0.15.1
|
, conduit >= 1.0.15.1
|
||||||
, async >= 2.0.1.5
|
, async >= 2.0.1.5
|
||||||
|
, base64-bytestring
|
||||||
|
, bytestring
|
||||||
|
, containers
|
||||||
|
, cryptohash
|
||||||
|
, blaze-builder
|
||||||
|
, http-types
|
||||||
|
, wai
|
||||||
|
, text
|
||||||
|
|
||||||
source-repository head
|
source-repository head
|
||||||
type: git
|
type: git
|
||||||
|
|||||||
Loading…
Reference in New Issue
Block a user