diff --git a/Crypto/ECC.hs b/Crypto/ECC.hs index 52890df..990105d 100644 --- a/Crypto/ECC.hs +++ b/Crypto/ECC.hs @@ -9,6 +9,7 @@ -- {-# LANGUAGE GeneralizedNewtypeDeriving #-} {-# LANGUAGE TypeFamilies #-} +{-# LANGUAGE ScopedTypeVariables #-} module Crypto.ECC ( Curve_P256R1(..) , Curve_P384R1(..) @@ -26,13 +27,12 @@ import qualified Crypto.PubKey.ECC.Types as H import qualified Crypto.PubKey.ECC.Prim as H import Crypto.Random import Crypto.Internal.Imports -import Crypto.Internal.ByteArray (ByteArrayAccess, ScrubbedBytes) +import Crypto.Internal.ByteArray (ByteArray, ByteArrayAccess, ScrubbedBytes) +import qualified Crypto.Internal.ByteArray as B import Crypto.Number.Serialize (i2ospOf_, os2ip) import qualified Crypto.PubKey.Curve25519 as X25519 import Data.Function (on) import Data.ByteArray (convert) -import Data.ByteString (ByteString) -import qualified Data.ByteString as B -- | An elliptic curve key pair composed of the private part (a scalar), and -- the associated point. @@ -70,8 +70,8 @@ class EllipticCurve curve where -- | Generate a new random keypair curveGenerateKeyPair :: MonadRandom randomly => randomly (KeyPair curve) - encodePoint :: Point curve -> ByteString - decodePoint :: ByteString -> Point curve + encodePoint :: ByteArray bs => Point curve -> bs + decodePoint :: ByteArray bs => bs -> Point curve instance {-# OVERLAPPABLE #-} Show (Point a) where show _ = undefined @@ -219,18 +219,22 @@ instance EllipticCurveDH Curve_X25519 where where secret = X25519.dh p s -encodeECPoint :: Integer -> Integer -> Int -> ByteString +encodeECPoint :: forall bs. ByteArray bs => Integer -> Integer -> Int -> bs encodeECPoint x y siz = B.concat [uncompressed,xb,yb] where + uncompressed, xb, yb :: bs uncompressed = B.singleton 4 xb = i2ospOf_ siz x yb = i2ospOf_ siz y -decodeECPoint :: ByteString -> (Integer,Integer) -decodeECPoint mxy = (x,y) - where - xy = B.drop 1 mxy -- dropping 4 (uncompressed) - siz = B.length xy `div` 2 - (xb,yb) = B.splitAt siz xy - x = os2ip xb - y = os2ip yb +decodeECPoint :: ByteArray bs => bs -> (Integer,Integer) +decodeECPoint mxy = case B.uncons mxy of + Nothing -> error "decodeECPoint" + Just (m,xy) + -- uncompressed + | m == 4 -> let siz = B.length xy `div` 2 + (xb,yb) = B.splitAt siz xy + x = os2ip xb + y = os2ip yb + in (x,y) + | otherwise -> error $ "decodeECPoint: unknown " ++ show m