diff --git a/yesod-websockets/WaiWS.hs b/yesod-websockets/WaiWS.hs new file mode 100644 index 00000000..d6281a0c --- /dev/null +++ b/yesod-websockets/WaiWS.hs @@ -0,0 +1,242 @@ +{-# LANGUAGE DeriveDataTypeable #-} +{-# LANGUAGE OverloadedStrings #-} +{-# LANGUAGE RecordWildCards #-} +module WaiWS where + +import Network.Wai +import Control.Exception (Exception, throwIO, assert) +import Control.Applicative ((<$>)) +import Control.Monad (when, forever, unless) +import Data.Typeable (Typeable) +import Network.HTTP.Types (status200, status404) +import Blaze.ByteString.Builder +import Data.Monoid ((<>), mempty) +import qualified Crypto.Hash.SHA1 as SHA1 +import Debug.Trace +import Data.Word (Word8, Word32, Word64) +import Data.ByteString (ByteString) +import Data.Bits ((.|.), testBit, clearBit, shiftL, (.&.), Bits, xor, shiftR) +import qualified Data.Map as Map +import Data.Maybe (isJust) +import qualified Data.ByteString as S +import qualified Data.ByteString.Char8 as S8 +import qualified Data.ByteString.Base64 as B64 +import Data.IORef +import Data.Char (toUpper) +import qualified Data.Conduit as C + + +data Connection = Connection + { connSend :: Bool -> ByteString -> IO () + , connRecv :: IO ByteString + } + +websocketsApp :: Request -> Maybe (C.Source IO ByteString -> C.Sink ByteString IO () -> (WaiWS.Connection -> IO a) -> IO a) +websocketsApp req + -- FIXME handle keep-alive, Upgrade | lookup "connection" reqhs /= Just "Upgrade" = backup sendResponse + | lookup "upgrade" reqhs /= Just "websocket" = Nothing + | lookup "sec-websocket-version" reqhs /= Just "13" = Nothing + | Just key <- lookup "sec-websocket-key" reqhs = Just $ \src0 sink app -> do + (rsrc0, ()) <- src0 C.$$+ return () + rsrcRef <- newIORef rsrc0 + let recv = do + rsrc <- readIORef rsrcRef + (rsrc', mbs) <- rsrc C.$$++ C.await + writeIORef rsrcRef rsrc' + case mbs of + Nothing -> return "" + Just "" -> recv + Just bs -> return bs + + let send x = C.yield x C.$$ sink + let handshake = fromByteString "HTTP/1.1 101 Switching Protocols\r\nUpgrade: websocket\r\nConnection: Upgrade\r\nSec-WebSocket-Accept: " + <> fromByteString (B64.encode key') + <> fromByteString "\r\n\r\n" + key' = SHA1.hash $ key <> "258EAFA5-E914-47DA-95CA-C5AB0DC85B11" + toByteStringIO send handshake + + let msg = "This is a test" + toByteStringIO send $ wsDataToBuilder $ Frame True OpText Nothing $ fromIntegral $ S.length msg + toByteStringIO send $ wsDataToBuilder $ Payload $ fromByteString msg + + src <- mkSource recv + + let recv front0 = waitForFrame src $ \isFinished opcode _ _ getBS -> do + let loop front = do + bs <- getBS + if S.null bs + then return front + else loop $ front . (bs:) + front <- loop front0 + if isFinished + then return $ S.concat $ front [] + else recv front + app Connection + { connSend = \isText payload -> do + toByteStringIO send $ wsDataToBuilder $ Frame True (if isText then OpText else OpBinary) Nothing $ fromIntegral $ S.length payload + send payload + , connRecv = recv id + } + | otherwise = Nothing + where + reqhs = requestHeaders req + +type FrameFinished = Bool +type MaskingKey = Word32 +type PayloadSize = Word64 + +data WSData payload + = Frame FrameFinished Opcode (Maybe MaskingKey) PayloadSize + | Payload payload + deriving Show + +data Opcode = OpCont | OpText | OpBinary | OpClose | OpPing | OpPong + deriving (Show, Eq, Ord, Enum, Bounded) + +opcodeToWord8 :: Opcode -> Word8 +opcodeToWord8 OpCont = 0x0 +opcodeToWord8 OpText = 0x1 +opcodeToWord8 OpBinary = 0x2 +opcodeToWord8 OpClose = 0x8 +opcodeToWord8 OpPing = 0x9 +opcodeToWord8 OpPong = 0xA + +opcodeFromWord8 :: Word8 -> Maybe Opcode +opcodeFromWord8 = + flip Map.lookup m + where + m = Map.fromList $ map (\o -> (opcodeToWord8 o, o)) [minBound..maxBound] + +wsDataToBuilder :: WSData Builder -> Builder +wsDataToBuilder (Payload builder) = builder +wsDataToBuilder (Frame finished opcode mmask payload) = + fromWord8 byte1 + <> fromWord8 byte2 + <> lenrest + <> maybe mempty fromWord32be mmask + where + byte1 = (if finished then 128 else 0) .|. opcodeToWord8 opcode + byte2 = (if isJust mmask then 128 else 0) .|. len1 + + (len1, lenrest) + | payload <= 125 = (fromIntegral payload, mempty) + | payload <= 65536 = (126, fromWord16be $ fromIntegral payload) + | otherwise = (127, fromWord64be $ fromIntegral payload) + +data WSException = ConnectionClosed + | RSVBitsSet Word8 + | InvalidOpcode Word8 + deriving (Show, Typeable) +instance Exception WSException + +data Source = Source (IO ByteString) (IORef ByteString) + +mkSource :: IO ByteString -> IO Source +mkSource recv = Source recv <$> newIORef S.empty + +-- | Guaranteed to never return an empty ByteString. +getBS :: Source -> IO ByteString +getBS (Source next ref) = do + bs <- readIORef ref + if S.null bs + then do + bs <- next + when (S.null bs) (throwIO ConnectionClosed) + return bs + else writeIORef ref S.empty >> return bs + +leftover :: Source -> ByteString -> IO () +leftover (Source _ ref) bs = writeIORef ref bs + +getWord8 :: Source -> IO Word8 +getWord8 src = do + bs <- getBS src + leftover src $ S.tail bs + return $ S.head bs + +getBytes :: (Num word, Bits word) => Source -> Int -> IO word +getBytes src = + loop 0 + where + loop total 0 = return total + loop total remaining = do + x <- getWord8 src -- FIXME not very efficient, better to use ByteString directly + loop (shiftL total 8 .|. fromIntegral x) (remaining - 1) + +waitForFrame :: Source -> (FrameFinished -> Opcode -> Maybe MaskingKey -> PayloadSize -> IO ByteString -> IO a) -> IO a +waitForFrame src yield = do + byte1 <- getWord8 src + byte2 <- getWord8 src + + when (testBit byte1 6 || testBit byte1 5 || testBit byte1 4) + $ throwIO $ RSVBitsSet byte1 + + let opcode' = byte1 .&. 0x0F + opcode <- + case opcodeFromWord8 opcode' of + Nothing -> throwIO $ InvalidOpcode opcode' + Just o -> return o + + let isFinished = testBit byte1 7 + isMasked = testBit byte2 7 + len' = byte2 `clearBit` 7 + + payloadSize <- + case () of + () + | len' <= 125 -> return $ fromIntegral len' + | len' == 126 -> getBytes src 2 + | assert (len' == 127) otherwise -> getBytes src 8 + + mmask <- if isMasked then Just <$> getBytes src 4 else return Nothing + let unmask' = + case mmask of + Nothing -> \_ bs -> bs + Just mask -> unmask mask + + consumedRef <- newIORef 0 + let getPayload = handlePayload unmask' payloadSize consumedRef + + res <- yield isFinished opcode mmask payloadSize getPayload + let drain = do + bs <- getPayload + unless (S.null bs) drain + drain + return res + + where + handlePayload unmask' totalSize consumedRef = do + consumed <- readIORef consumedRef + if consumed >= totalSize + then return S.empty + else do + bs <- getBS src + let len = fromIntegral $ S.length bs + consumed' = consumed + len + if consumed' <= totalSize + then do + writeIORef consumedRef consumed' + return $ unmask' consumed bs + else do + let (x, y) = S.splitAt (fromIntegral $ totalSize - consumed) bs + leftover src y + return $ unmask' consumed x + +unmask :: MaskingKey -> Word64 -> ByteString -> ByteString +unmask key offset' masked = + -- we really want a mapWithIndex... + fst $ S.unfoldrN len f 0 + where + len = S.length masked + + f idx | idx >= len = Nothing + f idx = Just (getIndex idx, idx + 1) + + offset = fromIntegral $ offset' `mod` 4 + + getIndex idx = S.index masked idx `xor` maskByte ((offset + idx) `mod` 4) + + maskByte 0 = fromIntegral $ key `shiftR` 24 + maskByte 1 = fromIntegral $ key `shiftR` 16 + maskByte 2 = fromIntegral $ key `shiftR` 8 + maskByte 3 = fromIntegral key diff --git a/yesod-websockets/Yesod/WebSockets.hs b/yesod-websockets/Yesod/WebSockets.hs index eebe9202..4390e7cf 100644 --- a/yesod-websockets/Yesod/WebSockets.hs +++ b/yesod-websockets/Yesod/WebSockets.hs @@ -9,6 +9,7 @@ module Yesod.WebSockets , sendBinaryData -- * Conduit API , sourceWS + , sourceWSText , sinkWSText , sinkWSBinary -- * Async helpers @@ -26,14 +27,16 @@ import Control.Monad.Trans.Control (MonadBaseControl (liftBaseWith, import Control.Monad.Trans.Reader (ReaderT (ReaderT, runReaderT)) import qualified Data.Conduit as C import qualified Data.Conduit.List as CL -import qualified Network.Wai.Handler.WebSockets as WaiWS -import qualified Network.WebSockets as WS import qualified Yesod.Core as Y +import qualified WaiWS +import Data.ByteString (ByteString) +import Data.Text (Text) +import Data.Text.Encoding (encodeUtf8, decodeUtf8) -- | A transformer for a WebSockets handler. -- -- Since 0.1.0 -type WebSocketsT = ReaderT WS.Connection +type WebSocketsT = ReaderT WaiWS.Connection -- | Attempt to run a WebSockets handler. This function first checks if the -- client initiated a WebSockets connection and, if so, runs the provided @@ -45,50 +48,50 @@ type WebSocketsT = ReaderT WS.Connection webSockets :: (Y.MonadBaseControl IO m, Y.MonadHandler m) => WebSocketsT m () -> m () webSockets inner = do req <- Y.waiRequest - when (WaiWS.isWebSocketsReq req) $ - Y.sendRawResponse $ \src sink -> control $ \runInIO -> WaiWS.runWebSockets - WS.defaultConnectionOptions - (WaiWS.getRequestHead req) - (\pconn -> do - conn <- WS.acceptRequest pconn - runInIO $ runReaderT inner conn) - src - sink + case WaiWS.websocketsApp req of + Nothing -> return () + Just runWebSockets -> Y.sendRawResponse $ \src sink -> control $ \runInIO -> runWebSockets src sink $ runInIO . runReaderT inner -- | Receive a piece of data from the client. -- -- Since 0.1.0 -receiveData :: (MonadIO m, WS.WebSocketsData a) => WebSocketsT m a -receiveData = ReaderT $ liftIO . WS.receiveData +receiveData :: (MonadIO m) => WebSocketsT m ByteString +receiveData = ReaderT $ liftIO . WaiWS.connRecv -- | Send a textual messsage to the client. -- -- Since 0.1.0 -sendTextData :: (MonadIO m, WS.WebSocketsData a) => a -> WebSocketsT m () -sendTextData x = ReaderT $ liftIO . flip WS.sendTextData x +sendTextData :: MonadIO m => Text -> WebSocketsT m () +sendTextData x = ReaderT $ \conn -> liftIO $ WaiWS.connSend conn True $ encodeUtf8 x -- | Send a binary messsage to the client. -- -- Since 0.1.0 -sendBinaryData :: (MonadIO m, WS.WebSocketsData a) => a -> WebSocketsT m () -sendBinaryData x = ReaderT $ liftIO . flip WS.sendBinaryData x +sendBinaryData :: MonadIO m => ByteString -> WebSocketsT m () +sendBinaryData x = ReaderT $ \conn -> liftIO $ WaiWS.connSend conn False x -- | A @Source@ of WebSockets data from the user. -- -- Since 0.1.0 -sourceWS :: (MonadIO m, WS.WebSocketsData a) => C.Producer (WebSocketsT m) a +sourceWS :: MonadIO m => C.Producer (WebSocketsT m) ByteString sourceWS = forever $ Y.lift receiveData >>= C.yield +-- | A @Source@ of WebSockets data from the user. +-- +-- Since 0.1.0 +sourceWSText :: MonadIO m => C.Producer (WebSocketsT m) Text +sourceWSText = forever $ Y.lift receiveData >>= C.yield . decodeUtf8 + -- | A @Sink@ for sending textual data to the user. -- -- Since 0.1.0 -sinkWSText :: (MonadIO m, WS.WebSocketsData a) => C.Consumer a (WebSocketsT m) () +sinkWSText :: MonadIO m => C.Consumer Text (WebSocketsT m) () sinkWSText = CL.mapM_ sendTextData -- | A @Sink@ for sending binary data to the user. -- -- Since 0.1.0 -sinkWSBinary :: (MonadIO m, WS.WebSocketsData a) => C.Consumer a (WebSocketsT m) () +sinkWSBinary :: MonadIO m => C.Consumer ByteString (WebSocketsT m) () sinkWSBinary = CL.mapM_ sendBinaryData -- | Generalized version of 'A.race'. diff --git a/yesod-websockets/sample.hs b/yesod-websockets/sample.hs index e369a99e..9f27c57a 100644 --- a/yesod-websockets/sample.hs +++ b/yesod-websockets/sample.hs @@ -1,6 +1,7 @@ {-# LANGUAGE QuasiQuotes, TemplateHaskell, TypeFamilies #-} import Yesod.Core import Yesod.WebSockets +import qualified Data.Text as T import qualified Data.Text.Lazy as TL import Control.Monad (forever) import Control.Monad.Trans.Reader @@ -16,16 +17,16 @@ mkYesod "App" [parseRoutes| / HomeR GET |] -timeSource :: MonadIO m => Source m TL.Text +timeSource :: MonadIO m => Source m T.Text timeSource = forever $ do now <- liftIO getCurrentTime - yield $ TL.pack $ show now + yield $ T.pack $ show now liftIO $ threadDelay 5000000 getHomeR :: Handler Html getHomeR = do webSockets $ race_ - (sourceWS $$ mapC TL.toUpper =$ sinkWSText) + (sourceWSText $$ mapC T.toUpper =$ sinkWSText) (timeSource $$ sinkWSText) defaultLayout $ toWidget diff --git a/yesod-websockets/yesod-websockets.cabal b/yesod-websockets/yesod-websockets.cabal index 49fb58b5..65dcf621 100644 --- a/yesod-websockets/yesod-websockets.cabal +++ b/yesod-websockets/yesod-websockets.cabal @@ -16,14 +16,21 @@ cabal-version: >=1.8 library exposed-modules: Yesod.WebSockets + other-modules: WaiWS build-depends: base >= 4.5 && < 5 - , wai-websockets >= 2.1 - , websockets >= 0.8 , transformers >= 0.2 , yesod-core >= 1.2.7 , monad-control >= 0.3 , conduit >= 1.0.15.1 , async >= 2.0.1.5 + , base64-bytestring + , bytestring + , containers + , cryptohash + , blaze-builder + , http-types + , wai + , text source-repository head type: git