Compare commits

...

1 Commits

Author SHA1 Message Date
Michael Snoyman
66437453f5 First stab at a native WebSockets implementation 2014-05-26 16:11:07 +03:00
4 changed files with 279 additions and 26 deletions

242
yesod-websockets/WaiWS.hs Normal file
View File

@ -0,0 +1,242 @@
{-# LANGUAGE DeriveDataTypeable #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE RecordWildCards #-}
module WaiWS where
import Network.Wai
import Control.Exception (Exception, throwIO, assert)
import Control.Applicative ((<$>))
import Control.Monad (when, forever, unless)
import Data.Typeable (Typeable)
import Network.HTTP.Types (status200, status404)
import Blaze.ByteString.Builder
import Data.Monoid ((<>), mempty)
import qualified Crypto.Hash.SHA1 as SHA1
import Debug.Trace
import Data.Word (Word8, Word32, Word64)
import Data.ByteString (ByteString)
import Data.Bits ((.|.), testBit, clearBit, shiftL, (.&.), Bits, xor, shiftR)
import qualified Data.Map as Map
import Data.Maybe (isJust)
import qualified Data.ByteString as S
import qualified Data.ByteString.Char8 as S8
import qualified Data.ByteString.Base64 as B64
import Data.IORef
import Data.Char (toUpper)
import qualified Data.Conduit as C
data Connection = Connection
{ connSend :: Bool -> ByteString -> IO ()
, connRecv :: IO ByteString
}
websocketsApp :: Request -> Maybe (C.Source IO ByteString -> C.Sink ByteString IO () -> (WaiWS.Connection -> IO a) -> IO a)
websocketsApp req
-- FIXME handle keep-alive, Upgrade | lookup "connection" reqhs /= Just "Upgrade" = backup sendResponse
| lookup "upgrade" reqhs /= Just "websocket" = Nothing
| lookup "sec-websocket-version" reqhs /= Just "13" = Nothing
| Just key <- lookup "sec-websocket-key" reqhs = Just $ \src0 sink app -> do
(rsrc0, ()) <- src0 C.$$+ return ()
rsrcRef <- newIORef rsrc0
let recv = do
rsrc <- readIORef rsrcRef
(rsrc', mbs) <- rsrc C.$$++ C.await
writeIORef rsrcRef rsrc'
case mbs of
Nothing -> return ""
Just "" -> recv
Just bs -> return bs
let send x = C.yield x C.$$ sink
let handshake = fromByteString "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: "
<> fromByteString (B64.encode key')
<> fromByteString "\r\n\r\n"
key' = SHA1.hash $ key <> "258EAFA5-E914-47DA-95CA-C5AB0DC85B11"
toByteStringIO send handshake
let msg = "This is a test"
toByteStringIO send $ wsDataToBuilder $ Frame True OpText Nothing $ fromIntegral $ S.length msg
toByteStringIO send $ wsDataToBuilder $ Payload $ fromByteString msg
src <- mkSource recv
let recv front0 = waitForFrame src $ \isFinished opcode _ _ getBS -> do
let loop front = do
bs <- getBS
if S.null bs
then return front
else loop $ front . (bs:)
front <- loop front0
if isFinished
then return $ S.concat $ front []
else recv front
app Connection
{ connSend = \isText payload -> do
toByteStringIO send $ wsDataToBuilder $ Frame True (if isText then OpText else OpBinary) Nothing $ fromIntegral $ S.length payload
send payload
, connRecv = recv id
}
| otherwise = Nothing
where
reqhs = requestHeaders req
type FrameFinished = Bool
type MaskingKey = Word32
type PayloadSize = Word64
data WSData payload
= Frame FrameFinished Opcode (Maybe MaskingKey) PayloadSize
| Payload payload
deriving Show
data Opcode = OpCont | OpText | OpBinary | OpClose | OpPing | OpPong
deriving (Show, Eq, Ord, Enum, Bounded)
opcodeToWord8 :: Opcode -> Word8
opcodeToWord8 OpCont = 0x0
opcodeToWord8 OpText = 0x1
opcodeToWord8 OpBinary = 0x2
opcodeToWord8 OpClose = 0x8
opcodeToWord8 OpPing = 0x9
opcodeToWord8 OpPong = 0xA
opcodeFromWord8 :: Word8 -> Maybe Opcode
opcodeFromWord8 =
flip Map.lookup m
where
m = Map.fromList $ map (\o -> (opcodeToWord8 o, o)) [minBound..maxBound]
wsDataToBuilder :: WSData Builder -> Builder
wsDataToBuilder (Payload builder) = builder
wsDataToBuilder (Frame finished opcode mmask payload) =
fromWord8 byte1
<> fromWord8 byte2
<> lenrest
<> maybe mempty fromWord32be mmask
where
byte1 = (if finished then 128 else 0) .|. opcodeToWord8 opcode
byte2 = (if isJust mmask then 128 else 0) .|. len1
(len1, lenrest)
| payload <= 125 = (fromIntegral payload, mempty)
| payload <= 65536 = (126, fromWord16be $ fromIntegral payload)
| otherwise = (127, fromWord64be $ fromIntegral payload)
data WSException = ConnectionClosed
| RSVBitsSet Word8
| InvalidOpcode Word8
deriving (Show, Typeable)
instance Exception WSException
data Source = Source (IO ByteString) (IORef ByteString)
mkSource :: IO ByteString -> IO Source
mkSource recv = Source recv <$> newIORef S.empty
-- | Guaranteed to never return an empty ByteString.
getBS :: Source -> IO ByteString
getBS (Source next ref) = do
bs <- readIORef ref
if S.null bs
then do
bs <- next
when (S.null bs) (throwIO ConnectionClosed)
return bs
else writeIORef ref S.empty >> return bs
leftover :: Source -> ByteString -> IO ()
leftover (Source _ ref) bs = writeIORef ref bs
getWord8 :: Source -> IO Word8
getWord8 src = do
bs <- getBS src
leftover src $ S.tail bs
return $ S.head bs
getBytes :: (Num word, Bits word) => Source -> Int -> IO word
getBytes src =
loop 0
where
loop total 0 = return total
loop total remaining = do
x <- getWord8 src -- FIXME not very efficient, better to use ByteString directly
loop (shiftL total 8 .|. fromIntegral x) (remaining - 1)
waitForFrame :: Source -> (FrameFinished -> Opcode -> Maybe MaskingKey -> PayloadSize -> IO ByteString -> IO a) -> IO a
waitForFrame src yield = do
byte1 <- getWord8 src
byte2 <- getWord8 src
when (testBit byte1 6 || testBit byte1 5 || testBit byte1 4)
$ throwIO $ RSVBitsSet byte1
let opcode' = byte1 .&. 0x0F
opcode <-
case opcodeFromWord8 opcode' of
Nothing -> throwIO $ InvalidOpcode opcode'
Just o -> return o
let isFinished = testBit byte1 7
isMasked = testBit byte2 7
len' = byte2 `clearBit` 7
payloadSize <-
case () of
()
| len' <= 125 -> return $ fromIntegral len'
| len' == 126 -> getBytes src 2
| assert (len' == 127) otherwise -> getBytes src 8
mmask <- if isMasked then Just <$> getBytes src 4 else return Nothing
let unmask' =
case mmask of
Nothing -> \_ bs -> bs
Just mask -> unmask mask
consumedRef <- newIORef 0
let getPayload = handlePayload unmask' payloadSize consumedRef
res <- yield isFinished opcode mmask payloadSize getPayload
let drain = do
bs <- getPayload
unless (S.null bs) drain
drain
return res
where
handlePayload unmask' totalSize consumedRef = do
consumed <- readIORef consumedRef
if consumed >= totalSize
then return S.empty
else do
bs <- getBS src
let len = fromIntegral $ S.length bs
consumed' = consumed + len
if consumed' <= totalSize
then do
writeIORef consumedRef consumed'
return $ unmask' consumed bs
else do
let (x, y) = S.splitAt (fromIntegral $ totalSize - consumed) bs
leftover src y
return $ unmask' consumed x
unmask :: MaskingKey -> Word64 -> ByteString -> ByteString
unmask key offset' masked =
-- we really want a mapWithIndex...
fst $ S.unfoldrN len f 0
where
len = S.length masked
f idx | idx >= len = Nothing
f idx = Just (getIndex idx, idx + 1)
offset = fromIntegral $ offset' `mod` 4
getIndex idx = S.index masked idx `xor` maskByte ((offset + idx) `mod` 4)
maskByte 0 = fromIntegral $ key `shiftR` 24
maskByte 1 = fromIntegral $ key `shiftR` 16
maskByte 2 = fromIntegral $ key `shiftR` 8
maskByte 3 = fromIntegral key

View File

@ -9,6 +9,7 @@ module Yesod.WebSockets
, sendBinaryData , sendBinaryData
-- * Conduit API -- * Conduit API
, sourceWS , sourceWS
, sourceWSText
, sinkWSText , sinkWSText
, sinkWSBinary , sinkWSBinary
-- * Async helpers -- * Async helpers
@ -26,14 +27,16 @@ import Control.Monad.Trans.Control (MonadBaseControl (liftBaseWith,
import Control.Monad.Trans.Reader (ReaderT (ReaderT, runReaderT)) import Control.Monad.Trans.Reader (ReaderT (ReaderT, runReaderT))
import qualified Data.Conduit as C import qualified Data.Conduit as C
import qualified Data.Conduit.List as CL import qualified Data.Conduit.List as CL
import qualified Network.Wai.Handler.WebSockets as WaiWS
import qualified Network.WebSockets as WS
import qualified Yesod.Core as Y import qualified Yesod.Core as Y
import qualified WaiWS
import Data.ByteString (ByteString)
import Data.Text (Text)
import Data.Text.Encoding (encodeUtf8, decodeUtf8)
-- | A transformer for a WebSockets handler. -- | A transformer for a WebSockets handler.
-- --
-- Since 0.1.0 -- Since 0.1.0
type WebSocketsT = ReaderT WS.Connection type WebSocketsT = ReaderT WaiWS.Connection
-- | Attempt to run a WebSockets handler. This function first checks if the -- | Attempt to run a WebSockets handler. This function first checks if the
-- client initiated a WebSockets connection and, if so, runs the provided -- client initiated a WebSockets connection and, if so, runs the provided
@ -45,50 +48,50 @@ type WebSocketsT = ReaderT WS.Connection
webSockets :: (Y.MonadBaseControl IO m, Y.MonadHandler m) => WebSocketsT m () -> m () webSockets :: (Y.MonadBaseControl IO m, Y.MonadHandler m) => WebSocketsT m () -> m ()
webSockets inner = do webSockets inner = do
req <- Y.waiRequest req <- Y.waiRequest
when (WaiWS.isWebSocketsReq req) $ case WaiWS.websocketsApp req of
Y.sendRawResponse $ \src sink -> control $ \runInIO -> WaiWS.runWebSockets Nothing -> return ()
WS.defaultConnectionOptions Just runWebSockets -> Y.sendRawResponse $ \src sink -> control $ \runInIO -> runWebSockets src sink $ runInIO . runReaderT inner
(WaiWS.getRequestHead req)
(\pconn -> do
conn <- WS.acceptRequest pconn
runInIO $ runReaderT inner conn)
src
sink
-- | Receive a piece of data from the client. -- | Receive a piece of data from the client.
-- --
-- Since 0.1.0 -- Since 0.1.0
receiveData :: (MonadIO m, WS.WebSocketsData a) => WebSocketsT m a receiveData :: (MonadIO m) => WebSocketsT m ByteString
receiveData = ReaderT $ liftIO . WS.receiveData receiveData = ReaderT $ liftIO . WaiWS.connRecv
-- | Send a textual messsage to the client. -- | Send a textual messsage to the client.
-- --
-- Since 0.1.0 -- Since 0.1.0
sendTextData :: (MonadIO m, WS.WebSocketsData a) => a -> WebSocketsT m () sendTextData :: MonadIO m => Text -> WebSocketsT m ()
sendTextData x = ReaderT $ liftIO . flip WS.sendTextData x sendTextData x = ReaderT $ \conn -> liftIO $ WaiWS.connSend conn True $ encodeUtf8 x
-- | Send a binary messsage to the client. -- | Send a binary messsage to the client.
-- --
-- Since 0.1.0 -- Since 0.1.0
sendBinaryData :: (MonadIO m, WS.WebSocketsData a) => a -> WebSocketsT m () sendBinaryData :: MonadIO m => ByteString -> WebSocketsT m ()
sendBinaryData x = ReaderT $ liftIO . flip WS.sendBinaryData x sendBinaryData x = ReaderT $ \conn -> liftIO $ WaiWS.connSend conn False x
-- | A @Source@ of WebSockets data from the user. -- | A @Source@ of WebSockets data from the user.
-- --
-- Since 0.1.0 -- Since 0.1.0
sourceWS :: (MonadIO m, WS.WebSocketsData a) => C.Producer (WebSocketsT m) a sourceWS :: MonadIO m => C.Producer (WebSocketsT m) ByteString
sourceWS = forever $ Y.lift receiveData >>= C.yield sourceWS = forever $ Y.lift receiveData >>= C.yield
-- | A @Source@ of WebSockets data from the user.
--
-- Since 0.1.0
sourceWSText :: MonadIO m => C.Producer (WebSocketsT m) Text
sourceWSText = forever $ Y.lift receiveData >>= C.yield . decodeUtf8
-- | A @Sink@ for sending textual data to the user. -- | A @Sink@ for sending textual data to the user.
-- --
-- Since 0.1.0 -- Since 0.1.0
sinkWSText :: (MonadIO m, WS.WebSocketsData a) => C.Consumer a (WebSocketsT m) () sinkWSText :: MonadIO m => C.Consumer Text (WebSocketsT m) ()
sinkWSText = CL.mapM_ sendTextData sinkWSText = CL.mapM_ sendTextData
-- | A @Sink@ for sending binary data to the user. -- | A @Sink@ for sending binary data to the user.
-- --
-- Since 0.1.0 -- Since 0.1.0
sinkWSBinary :: (MonadIO m, WS.WebSocketsData a) => C.Consumer a (WebSocketsT m) () sinkWSBinary :: MonadIO m => C.Consumer ByteString (WebSocketsT m) ()
sinkWSBinary = CL.mapM_ sendBinaryData sinkWSBinary = CL.mapM_ sendBinaryData
-- | Generalized version of 'A.race'. -- | Generalized version of 'A.race'.

View File

@ -1,6 +1,7 @@
{-# LANGUAGE QuasiQuotes, TemplateHaskell, TypeFamilies #-} {-# LANGUAGE QuasiQuotes, TemplateHaskell, TypeFamilies #-}
import Yesod.Core import Yesod.Core
import Yesod.WebSockets import Yesod.WebSockets
import qualified Data.Text as T
import qualified Data.Text.Lazy as TL import qualified Data.Text.Lazy as TL
import Control.Monad (forever) import Control.Monad (forever)
import Control.Monad.Trans.Reader import Control.Monad.Trans.Reader
@ -16,16 +17,16 @@ mkYesod "App" [parseRoutes|
/ HomeR GET / HomeR GET
|] |]
timeSource :: MonadIO m => Source m TL.Text timeSource :: MonadIO m => Source m T.Text
timeSource = forever $ do timeSource = forever $ do
now <- liftIO getCurrentTime now <- liftIO getCurrentTime
yield $ TL.pack $ show now yield $ T.pack $ show now
liftIO $ threadDelay 5000000 liftIO $ threadDelay 5000000
getHomeR :: Handler Html getHomeR :: Handler Html
getHomeR = do getHomeR = do
webSockets $ race_ webSockets $ race_
(sourceWS $$ mapC TL.toUpper =$ sinkWSText) (sourceWSText $$ mapC T.toUpper =$ sinkWSText)
(timeSource $$ sinkWSText) (timeSource $$ sinkWSText)
defaultLayout $ defaultLayout $
toWidget toWidget

View File

@ -16,14 +16,21 @@ cabal-version: >=1.8
library library
exposed-modules: Yesod.WebSockets exposed-modules: Yesod.WebSockets
other-modules: WaiWS
build-depends: base >= 4.5 && < 5 build-depends: base >= 4.5 && < 5
, wai-websockets >= 2.1
, websockets >= 0.8
, transformers >= 0.2 , transformers >= 0.2
, yesod-core >= 1.2.7 , yesod-core >= 1.2.7
, monad-control >= 0.3 , monad-control >= 0.3
, conduit >= 1.0.15.1 , conduit >= 1.0.15.1
, async >= 2.0.1.5 , async >= 2.0.1.5
, base64-bytestring
, bytestring
, containers
, cryptohash
, blaze-builder
, http-types
, wai
, text
source-repository head source-repository head
type: git type: git