[Salsa] use more ByteArray operations instead of bytestring.
This commit is contained in:
parent
87e2862eaa
commit
b497737ef1
@ -13,13 +13,13 @@ module Crypto.Cipher.Salsa
|
||||
, State
|
||||
) where
|
||||
|
||||
import Data.SecureMem
|
||||
import Data.ByteString (ByteString)
|
||||
import qualified Data.ByteString.Internal as B
|
||||
import qualified Data.ByteString as B
|
||||
import Crypto.Internal.ByteArray (ByteArrayAccess, ByteArray, SecureBytes)
|
||||
import qualified Crypto.Internal.ByteArray as B
|
||||
import qualified Data.ByteString.Internal as BS
|
||||
import qualified Data.ByteString as BS
|
||||
import Crypto.Internal.Compat
|
||||
import Crypto.Internal.Imports
|
||||
import Data.Byteable
|
||||
import Data.Bits (xor)
|
||||
import Foreign.Ptr
|
||||
import Foreign.ForeignPtr
|
||||
@ -27,9 +27,9 @@ import Foreign.C.Types
|
||||
import Foreign.Storable
|
||||
|
||||
-- | Salsa context
|
||||
data State = State Int -- number of rounds
|
||||
SecureMem -- Salsa's state
|
||||
ByteString -- previous generated chunk
|
||||
data State = State Int -- number of rounds
|
||||
SecureBytes -- Salsa's state
|
||||
ByteString -- previous generated chunk
|
||||
|
||||
round64 :: Int -> (Bool, Int)
|
||||
round64 len
|
||||
@ -40,36 +40,37 @@ round64 len
|
||||
|
||||
-- | Initialize a new Salsa context with the number of rounds,
|
||||
-- the key and the nonce associated.
|
||||
initialize :: Byteable key
|
||||
=> Int -- ^ number of rounds (8,12,20)
|
||||
-> key -- ^ the key (128 or 256 bits)
|
||||
-> ByteString -- ^ the nonce (64 or 96 bits)
|
||||
-> State -- ^ the initial Salsa state
|
||||
initialize :: (ByteArrayAccess key, ByteArray nonce)
|
||||
=> Int -- ^ number of rounds (8,12,20)
|
||||
-> key -- ^ the key (128 or 256 bits)
|
||||
-> nonce -- ^ the nonce (64 or 96 bits)
|
||||
-> State -- ^ the initial Salsa state
|
||||
initialize nbRounds key nonce
|
||||
| not (kLen `elem` [16,32]) = error "Salsa: key length should be 128 or 256 bits"
|
||||
| not (nonceLen `elem` [8,12]) = error "Salsa: nonce length should be 64 or 96 bits"
|
||||
| not (nbRounds `elem` [8,12,20]) = error "Salsa: rounds should be 8, 12 or 20"
|
||||
| otherwise = unsafeDoIO $ do
|
||||
stPtr <- createSecureMem 64 $ \stPtr ->
|
||||
withBytePtr nonce $ \noncePtr ->
|
||||
withBytePtr key $ \keyPtr ->
|
||||
ccryptonite_salsa_init (castPtr stPtr) kLen keyPtr nonceLen noncePtr
|
||||
stPtr <- B.alloc 64 $ \stPtr ->
|
||||
B.withByteArray nonce $ \noncePtr ->
|
||||
B.withByteArray key $ \keyPtr ->
|
||||
ccryptonite_salsa_init stPtr kLen keyPtr nonceLen noncePtr
|
||||
return $ State nbRounds stPtr B.empty
|
||||
where kLen = byteableLength key
|
||||
where kLen = B.length key
|
||||
nonceLen = B.length nonce
|
||||
|
||||
-- | Combine the salsa output and an arbitrary message with a xor,
|
||||
-- and return the combined output and the new state.
|
||||
combine :: State -- ^ the current Salsa state
|
||||
-> ByteString -- ^ the source to xor with the generator
|
||||
-> (ByteString, State)
|
||||
combine :: ByteArray ba
|
||||
=> State -- ^ the current Salsa state
|
||||
-> ba -- ^ the source to xor with the generator
|
||||
-> (ba, State)
|
||||
combine prev@(State nbRounds prevSt prevOut) src
|
||||
| outputLen == 0 = (B.empty, prev)
|
||||
| outputLen <= prevBufLen =
|
||||
-- we have enough byte in the previous buffer to complete the query
|
||||
-- without having to generate any extra bytes
|
||||
let (b1,b2) = B.splitAt outputLen prevOut
|
||||
in (B.pack $ B.zipWith xor b1 src, State nbRounds prevSt b2)
|
||||
let (b1,b2) = BS.splitAt outputLen prevOut
|
||||
in (B.convert $ BS.pack $ BS.zipWith xor b1 (B.convert src), State nbRounds prevSt b2)
|
||||
| otherwise = unsafeDoIO $ do
|
||||
-- adjusted len is the number of bytes lefts to generate after
|
||||
-- copying from the previous buffer.
|
||||
@ -77,25 +78,23 @@ combine prev@(State nbRounds prevSt prevOut) src
|
||||
(roundedAlready, newBytesToGenerate) = round64 adjustedLen
|
||||
nextBufLen = newBytesToGenerate - adjustedLen
|
||||
|
||||
fptr <- B.mallocByteString (newBytesToGenerate + prevBufLen)
|
||||
fptr <- BS.mallocByteString (newBytesToGenerate + prevBufLen)
|
||||
newSt <- withForeignPtr fptr $ \dstPtr ->
|
||||
withBytePtr src $ \srcPtr -> do
|
||||
B.withByteArray src $ \srcPtr -> do
|
||||
-- copy the previous buffer by xor if any
|
||||
withBytePtr prevOut $ \prevPtr ->
|
||||
B.withByteArray prevOut $ \prevPtr ->
|
||||
loopXor dstPtr srcPtr prevPtr prevBufLen
|
||||
|
||||
-- then create a new mutable copy of state
|
||||
st <- secureMemCopy prevSt
|
||||
withSecureMemPtr st $ \stPtr ->
|
||||
B.copy prevSt $ \stPtr ->
|
||||
ccryptonite_salsa_combine nbRounds
|
||||
(dstPtr `plusPtr` prevBufLen)
|
||||
(castPtr stPtr)
|
||||
(srcPtr `plusPtr` prevBufLen)
|
||||
(fromIntegral newBytesToGenerate)
|
||||
return st
|
||||
-- return combined byte
|
||||
return ( B.PS fptr 0 outputLen
|
||||
, State nbRounds newSt (if roundedAlready then B.empty else B.PS fptr outputLen nextBufLen))
|
||||
return ( B.convert (BS.PS fptr 0 outputLen)
|
||||
, State nbRounds newSt (if roundedAlready then BS.empty else BS.PS fptr outputLen nextBufLen))
|
||||
where
|
||||
outputLen = B.length src
|
||||
prevBufLen = B.length prevOut
|
||||
@ -109,10 +108,11 @@ combine prev@(State nbRounds prevSt prevOut) src
|
||||
-- | Generate a number of bytes from the Salsa output directly
|
||||
--
|
||||
-- TODO: use salsa_generate directly instead of using combine xor'ing with 0.
|
||||
generate :: State -- ^ the current Salsa state
|
||||
generate :: ByteArray ba
|
||||
=> State -- ^ the current Salsa state
|
||||
-> Int -- ^ the length of data to generate
|
||||
-> (ByteString, State)
|
||||
generate st len = combine st (B.replicate len 0)
|
||||
-> (ba, State)
|
||||
generate st len = combine st (B.zero len)
|
||||
|
||||
foreign import ccall "cryptonite_salsa_init"
|
||||
ccryptonite_salsa_init :: Ptr State -> Int -> Ptr Word8 -> Int -> Ptr Word8 -> IO ()
|
||||
|
||||
Loading…
Reference in New Issue
Block a user