use securemem abstraction and byteable helpers in RC4

This commit is contained in:
Vincent Hanquez 2014-07-21 11:17:42 +01:00
parent f2bfecfa3e
commit 9c9007c3b0

View File

@ -22,14 +22,13 @@ module Crypto.Cipher.RC4
import Data.Word import Data.Word
import Data.Byteable import Data.Byteable
import Data.SecureMem
import Foreign.Ptr import Foreign.Ptr
import Foreign.ForeignPtr import Foreign.ForeignPtr
import System.IO.Unsafe import System.IO.Unsafe
import Data.Byteable
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import qualified Data.ByteString as B import qualified Data.ByteString as B
import qualified Data.ByteString.Internal as B import qualified Data.ByteString.Internal as B
import Control.Applicative ((<$>))
---------------------------------------------------------------------- ----------------------------------------------------------------------
unsafeDoIO :: IO a -> a unsafeDoIO :: IO a -> a
@ -40,13 +39,13 @@ unsafeDoIO = unsafePerformIO
#endif #endif
-- | The encryption state for RC4 -- | The encryption state for RC4
newtype State = State ByteString newtype State = State SecureMem
-- | C Call for initializing the encryptor -- | C Call for initializing the encryptor
foreign import ccall unsafe "cryptonite_rc4.h cryptonite_rc4_init" foreign import ccall unsafe "cryptonite_rc4.h cryptonite_rc4_init"
c_rc4_init :: Ptr Word8 -- ^ The rc4 key c_rc4_init :: Ptr Word8 -- ^ The rc4 key
-> Word32 -- ^ The key length -> Word32 -- ^ The key length
-> Ptr State -- ^ The context -> Ptr State -- ^ The context
-> IO () -> IO ()
foreign import ccall unsafe "cryptonite_rc4.h cryptonite_rc4_combine" foreign import ccall unsafe "cryptonite_rc4.h cryptonite_rc4_combine"
@ -56,10 +55,6 @@ foreign import ccall unsafe "cryptonite_rc4.h cryptonite_rc4_combine"
-> Ptr Word8 -- ^ Output buffer -> Ptr Word8 -- ^ Output buffer
-> IO () -> IO ()
withByteStringPtr :: ByteString -> (Ptr Word8 -> IO a) -> IO a
withByteStringPtr b f = withForeignPtr fptr $ \ptr -> f (ptr `plusPtr` off)
where (fptr, off, _) = B.toForeignPtr b
-- | RC4 context initialization. -- | RC4 context initialization.
-- --
-- seed the context with an initial key. the key size need to be -- seed the context with an initial key. the key size need to be
@ -68,7 +63,9 @@ initialize :: Byteable key
=> key -- ^ The key => key -- ^ The key
-> State -- ^ The RC4 context with the key mixed in -> State -- ^ The RC4 context with the key mixed in
initialize key = unsafeDoIO $ do initialize key = unsafeDoIO $ do
State <$> (B.create 264 $ \ctx -> withBytePtr key $ \keyPtr -> c_rc4_init (castPtr keyPtr) (fromIntegral $ byteableLength key) (castPtr ctx)) st <- createSecureMem 264 $ \stPtr ->
withBytePtr key $ \keyPtr -> c_rc4_init keyPtr (fromIntegral $ byteableLength key) (castPtr stPtr)
return $ State st
-- | generate the next len bytes of the rc4 stream without combining -- | generate the next len bytes of the rc4 stream without combining
-- it to anything. -- it to anything.
@ -79,14 +76,12 @@ generate ctx len = combine ctx (B.replicate len 0)
combine :: State -- ^ rc4 context combine :: State -- ^ rc4 context
-> ByteString -- ^ input -> ByteString -- ^ input
-> (State, ByteString) -- ^ new rc4 context, and the output -> (State, ByteString) -- ^ new rc4 context, and the output
combine (State cctx) clearText = unsafeDoIO $ combine (State prevSt) clearText = unsafeDoIO $ do
B.mallocByteString 264 >>= \dctx -> outfptr <- B.mallocByteString len
B.mallocByteString len >>= \outfptr -> st <- secureMemCopy prevSt
withByteStringPtr clearText $ \clearPtr -> withSecureMemPtr st $ \stPtr ->
withByteStringPtr cctx $ \srcState -> withForeignPtr outfptr $ \outptr ->
withForeignPtr dctx $ \dstState -> do withBytePtr clearText $ \clearPtr ->
withForeignPtr outfptr $ \outptr -> do c_rc4_combine (castPtr stPtr) clearPtr (fromIntegral len) outptr
B.memcpy dstState srcState 264 return $! (State st, B.PS outfptr 0 len)
c_rc4_combine (castPtr dstState) clearPtr (fromIntegral len) outptr where len = B.length clearText
return $! (State $! B.PS dctx 0 264, B.PS outfptr 0 len)
where len = B.length clearText