[Salsa] use more ByteArray operations instead of bytestring.

This commit is contained in:
Vincent Hanquez 2015-04-30 06:07:25 +01:00
parent 87e2862eaa
commit b497737ef1

View File

@ -13,13 +13,13 @@ module Crypto.Cipher.Salsa
, State , State
) where ) where
import Data.SecureMem
import Data.ByteString (ByteString) import Data.ByteString (ByteString)
import qualified Data.ByteString.Internal as B import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, SecureBytes)
import qualified Data.ByteString as B import qualified Crypto.Internal.ByteArray as B
import qualified Data.ByteString.Internal as BS
import qualified Data.ByteString as BS
import Crypto.Internal.Compat import Crypto.Internal.Compat
import Crypto.Internal.Imports import Crypto.Internal.Imports
import Data.Byteable
import Data.Bits (xor) import Data.Bits (xor)
import Foreign.Ptr import Foreign.Ptr
import Foreign.ForeignPtr import Foreign.ForeignPtr
@ -27,9 +27,9 @@ import Foreign.C.Types
import Foreign.Storable import Foreign.Storable
-- | Salsa context -- | Salsa context
data State = State Int -- number of rounds data State = State Int -- number of rounds
SecureMem -- Salsa's state SecureBytes -- Salsa's state
ByteString -- previous generated chunk ByteString -- previous generated chunk
round64 :: Int -> (Bool, Int) round64 :: Int -> (Bool, Int)
round64 len round64 len
@ -40,36 +40,37 @@ round64 len
-- | Initialize a new Salsa context with the number of rounds, -- | Initialize a new Salsa context with the number of rounds,
-- the key and the nonce associated. -- the key and the nonce associated.
initialize :: Byteable key initialize :: (ByteArrayAccess key, ByteArray nonce)
=> Int -- ^ number of rounds (8,12,20) => Int -- ^ number of rounds (8,12,20)
-> key -- ^ the key (128 or 256 bits) -> key -- ^ the key (128 or 256 bits)
-> ByteString -- ^ the nonce (64 or 96 bits) -> nonce -- ^ the nonce (64 or 96 bits)
-> State -- ^ the initial Salsa state -> State -- ^ the initial Salsa state
initialize nbRounds key nonce initialize nbRounds key nonce
| not (kLen `elem` [16,32]) = error "Salsa: key length should be 128 or 256 bits" | not (kLen `elem` [16,32]) = error "Salsa: key length should be 128 or 256 bits"
| not (nonceLen `elem` [8,12]) = error "Salsa: nonce length should be 64 or 96 bits" | not (nonceLen `elem` [8,12]) = error "Salsa: nonce length should be 64 or 96 bits"
| not (nbRounds `elem` [8,12,20]) = error "Salsa: rounds should be 8, 12 or 20" | not (nbRounds `elem` [8,12,20]) = error "Salsa: rounds should be 8, 12 or 20"
| otherwise = unsafeDoIO $ do | otherwise = unsafeDoIO $ do
stPtr <- createSecureMem 64 $ \stPtr -> stPtr <- B.alloc 64 $ \stPtr ->
withBytePtr nonce $ \noncePtr -> B.withByteArray nonce $ \noncePtr ->
withBytePtr key $ \keyPtr -> B.withByteArray key $ \keyPtr ->
ccryptonite_salsa_init (castPtr stPtr) kLen keyPtr nonceLen noncePtr ccryptonite_salsa_init stPtr kLen keyPtr nonceLen noncePtr
return $ State nbRounds stPtr B.empty return $ State nbRounds stPtr B.empty
where kLen = byteableLength key where kLen = B.length key
nonceLen = B.length nonce nonceLen = B.length nonce
-- | Combine the salsa output and an arbitrary message with a xor, -- | Combine the salsa output and an arbitrary message with a xor,
-- and return the combined output and the new state. -- and return the combined output and the new state.
combine :: State -- ^ the current Salsa state combine :: ByteArray ba
-> ByteString -- ^ the source to xor with the generator => State -- ^ the current Salsa state
-> (ByteString, State) -> ba -- ^ the source to xor with the generator
-> (ba, State)
combine prev@(State nbRounds prevSt prevOut) src combine prev@(State nbRounds prevSt prevOut) src
| outputLen == 0 = (B.empty, prev) | outputLen == 0 = (B.empty, prev)
| outputLen <= prevBufLen = | outputLen <= prevBufLen =
-- we have enough byte in the previous buffer to complete the query -- we have enough byte in the previous buffer to complete the query
-- without having to generate any extra bytes -- without having to generate any extra bytes
let (b1,b2) = B.splitAt outputLen prevOut let (b1,b2) = BS.splitAt outputLen prevOut
in (B.pack $ B.zipWith xor b1 src, State nbRounds prevSt b2) in (B.convert $ BS.pack $ BS.zipWith xor b1 (B.convert src), State nbRounds prevSt b2)
| otherwise = unsafeDoIO $ do | otherwise = unsafeDoIO $ do
-- adjusted len is the number of bytes lefts to generate after -- adjusted len is the number of bytes lefts to generate after
-- copying from the previous buffer. -- copying from the previous buffer.
@ -77,25 +78,23 @@ combine prev@(State nbRounds prevSt prevOut) src
(roundedAlready, newBytesToGenerate) = round64 adjustedLen (roundedAlready, newBytesToGenerate) = round64 adjustedLen
nextBufLen = newBytesToGenerate - adjustedLen nextBufLen = newBytesToGenerate - adjustedLen
fptr <- B.mallocByteString (newBytesToGenerate + prevBufLen) fptr <- BS.mallocByteString (newBytesToGenerate + prevBufLen)
newSt <- withForeignPtr fptr $ \dstPtr -> newSt <- withForeignPtr fptr $ \dstPtr ->
withBytePtr src $ \srcPtr -> do B.withByteArray src $ \srcPtr -> do
-- copy the previous buffer by xor if any -- copy the previous buffer by xor if any
withBytePtr prevOut $ \prevPtr -> B.withByteArray prevOut $ \prevPtr ->
loopXor dstPtr srcPtr prevPtr prevBufLen loopXor dstPtr srcPtr prevPtr prevBufLen
-- then create a new mutable copy of state -- then create a new mutable copy of state
st <- secureMemCopy prevSt B.copy prevSt $ \stPtr ->
withSecureMemPtr st $ \stPtr ->
ccryptonite_salsa_combine nbRounds ccryptonite_salsa_combine nbRounds
(dstPtr `plusPtr` prevBufLen) (dstPtr `plusPtr` prevBufLen)
(castPtr stPtr) (castPtr stPtr)
(srcPtr `plusPtr` prevBufLen) (srcPtr `plusPtr` prevBufLen)
(fromIntegral newBytesToGenerate) (fromIntegral newBytesToGenerate)
return st
-- return combined byte -- return combined byte
return ( B.PS fptr 0 outputLen return ( B.convert (BS.PS fptr 0 outputLen)
, State nbRounds newSt (if roundedAlready then B.empty else B.PS fptr outputLen nextBufLen)) , State nbRounds newSt (if roundedAlready then BS.empty else BS.PS fptr outputLen nextBufLen))
where where
outputLen = B.length src outputLen = B.length src
prevBufLen = B.length prevOut prevBufLen = B.length prevOut
@ -109,10 +108,11 @@ combine prev@(State nbRounds prevSt prevOut) src
-- | Generate a number of bytes from the Salsa output directly -- | Generate a number of bytes from the Salsa output directly
-- --
-- TODO: use salsa_generate directly instead of using combine xor'ing with 0. -- TODO: use salsa_generate directly instead of using combine xor'ing with 0.
generate :: State -- ^ the current Salsa state generate :: ByteArray ba
=> State -- ^ the current Salsa state
-> Int -- ^ the length of data to generate -> Int -- ^ the length of data to generate
-> (ByteString, State) -> (ba, State)
generate st len = combine st (B.replicate len 0) generate st len = combine st (B.zero len)
foreign import ccall "cryptonite_salsa_init" foreign import ccall "cryptonite_salsa_init"
ccryptonite_salsa_init :: Ptr State -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO () ccryptonite_salsa_init :: Ptr State -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO ()