diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index c592c82..1281833 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -231,7 +231,7 @@ newIdentFor = Q . lift . try . unDBName -- | Information needed to escape and use identifiers. -type IdentInfo = (Connection, IdentState) +type IdentInfo = (SqlBackend, IdentState) -- | Use an identifier. @@ -257,6 +257,7 @@ data SqlExpr a where -- interpolated by the SQL backend. ERaw :: NeedParens -> (IdentInfo -> (TLB.Builder, [PersistValue])) -> SqlExpr (Value a) + -- Used to support non-id/composite primary keys ERawList :: (IdentInfo -> ([TLB.Builder], [PersistValue])) -> SqlExpr (Value a) -- 'EList' and 'EEmptyList' are used by list operators. @@ -355,6 +356,7 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where where maybelize :: SqlExpr (Value a) -> SqlExpr (Value (Maybe a)) maybelize (ERaw p f) = ERaw p f + maybelize (ERawList f) = ERawList f val v = case v' of PersistList vs -> ERawList $ const (replicate (length vs) "?", vs) @@ -363,15 +365,20 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where isNothing (ERaw p f) = ERaw Parens $ first ((<> " IS NULL") . parensM p) . f + isNothing (ERawList f) = ERaw Parens $ first (intersperseB " AND " . map (<> " IS NULL")) . f just (ERaw p f) = ERaw p f + just (ERawList f) = ERawList f nothing = unsafeSqlValue "NULL" joinV (ERaw p f) = ERaw p f + joinV (ERawList f) = ERawList f countRows = unsafeSqlValue "COUNT(*)" count (ERaw _ f) = ERaw Never $ \info -> let (b, vals) = f info in ("COUNT" <> parens b, vals) + count (ERawList _) = unsafeSqlValue "COUNT(*)" -- Assumes no NULLs on a PK not_ (ERaw p f) = ERaw Never $ \info -> let (b, vals) = f info in ("NOT " <> parensM p b, vals) + not_ (ERawList f) = ERaw Parens $ first (intersperseB " AND " . map ("NOT " <>)) . f (==.) = unsafeSqlBinOpList " = " " AND " (>=.) = unsafeSqlBinOp " >= " @@ -427,11 +434,13 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where field /=. expr = setAux field (\ent -> ent ^. field /. expr) (<#) _ (ERaw _ f) = EInsert Proxy f + (<#) _ (ERawList _) = error "<# not supported on composite" (EInsert _ f) <&> (ERaw _ g) = EInsert Proxy $ \x -> let (fb, fv) = f x (gb, gv) = g x in (fb <> ", " <> gb, fv ++ gv) + (EInsert _ _) <&> (ERawList _) = error "<&> not supported on composite" case_ = unsafeSqlCase @@ -473,23 +482,25 @@ ifNotEmptyList (EList _) _ x = x -- -- Since: 2.1.1 unsafeSqlCase :: PersistField a => [(SqlExpr (Value Bool), SqlExpr (Value a))] -> SqlExpr (Value a) -> SqlExpr (Value a) -unsafeSqlCase when_ (ERaw p1 f1) = ERaw Never buildCase +unsafeSqlCase when (ERaw p1 f1) = ERaw Never buildCase where buildCase :: IdentInfo -> (TLB.Builder, [PersistValue]) buildCase info = let (b1, vals1) = f1 info - (b2, vals2) = mapWhen when_ info + (b2, vals2) = mapWhen when info in ( "CASE" <> b2 <> " ELSE " <> parensM p1 b1 <> " END", vals2 <> vals1) mapWhen :: [(SqlExpr (Value Bool), SqlExpr (Value a))] -> IdentInfo -> (TLB.Builder, [PersistValue]) mapWhen [] _ = error "unsafeSqlCase: empty when_ list." - mapWhen when_ info = foldl (foldHelp info) (mempty, mempty) when_ + mapWhen when' info = foldl (foldHelp info) (mempty, mempty) when' foldHelp :: IdentInfo -> (TLB.Builder, [PersistValue]) -> (SqlExpr (Value Bool), SqlExpr (Value a)) -> (TLB.Builder, [PersistValue]) - foldHelp info (b0, vals0) (ERaw p1 f1, ERaw p2 f2) = - let (b1, vals1) = f1 info + foldHelp info (b0, vals0) (ERaw p1' f1', ERaw p2 f2) = + let (b1, vals1) = f1' info (b2, vals2) = f2 info - in ( b0 <> " WHEN " <> parensM p1 b1 <> " THEN " <> parensM p2 b2, vals0 <> vals1 <> vals2 ) + in ( b0 <> " WHEN " <> parensM p1' b1 <> " THEN " <> parensM p2 b2, vals0 <> vals1 <> vals2 ) + foldHelp _ _ _ = error "non-id/composite PKs not supported on cae constructs" +unsafeSqlCase _ (ERawList _) = error "non-id/composite PKs not supported on cae constructs" -- | (Internal) Create a custom binary operator. You /should/ @@ -511,6 +522,8 @@ unsafeSqlBinOp op (ERaw p1 f1) (ERaw p2 f2) = ERaw Parens f (b2, vals2) = f2 info in ( parensM p1 b1 <> op <> parensM p2 b2 , vals1 <> vals2 ) +unsafeSqlBinOp op _ _ = error . TL.unpack . TLB.toLazyText $ + "Operator '" <> op <> "' not supported on non-id/composite primary keys" {-# INLINE unsafeSqlBinOp #-} unsafeSqlBinOpList :: TLB.Builder -> TLB.Builder -> SqlExpr (Value a) -> SqlExpr (Value b) -> SqlExpr (Value c) @@ -590,6 +603,7 @@ instance ( UnsafeSqlFunctionArgument a -- unless you know what you're doing! veryUnsafeCoerceSqlExprValue :: SqlExpr (Value a) -> SqlExpr (Value b) veryUnsafeCoerceSqlExprValue (ERaw p f) = ERaw p f +veryUnsafeCoerceSqlExprValue (ERawList f) = ERawList f -- | (Internal) Coerce a value's type from 'SqlExpr (ValueList