diff --git a/src/Database/Esqueleto/Internal/Sql.hs b/src/Database/Esqueleto/Internal/Sql.hs index 1c73b85..205a462 100644 --- a/src/Database/Esqueleto/Internal/Sql.hs +++ b/src/Database/Esqueleto/Internal/Sql.hs @@ -35,6 +35,9 @@ module Database.Esqueleto.Internal.Sql , rawEsqueleto , toRawSql , Mode(..) + , IdentState + , initialIdentState + , IdentInfo , SqlSelect , veryUnsafeCoerceSqlExprValue ) where @@ -219,9 +222,13 @@ newIdentFor = Q . lift . try . unDBName return (I t) +-- | Information needed to escape and use identifiers. +type IdentInfo = (Connection, IdentState) + + -- | Use an identifier. -useIdent :: Connection -> Ident -> TLB.Builder -useIdent conn (I ident) = fromDBName conn $ DBName ident +useIdent :: IdentInfo -> Ident -> TLB.Builder +useIdent info (I ident) = fromDBName info $ DBName ident ---------------------------------------------------------------------- @@ -240,7 +247,7 @@ data SqlExpr a where -- connection (mainly for escaping names) and returns both an -- string ('TLB.Builder') and a list of values to be -- interpolated by the SQL backend. - ERaw :: NeedParens -> (Connection -> (TLB.Builder, [PersistValue])) -> SqlExpr (Value a) + ERaw :: NeedParens -> (IdentInfo -> (TLB.Builder, [PersistValue])) -> SqlExpr (Value a) -- | 'EList' and 'EEmptyList' are used by list operators. EList :: SqlExpr (Value a) -> SqlExpr (ValueList a) @@ -256,7 +263,7 @@ data SqlExpr a where EPreprocessedFrom :: a -> FromClause -> SqlExpr (PreprocessedFrom a) -- | Used by 'insertSelect'. - EInsert :: Proxy a -> (Connection -> (TLB.Builder, [PersistValue])) -> SqlExpr (Insertion a) + EInsert :: Proxy a -> (IdentInfo -> (TLB.Builder, [PersistValue])) -> SqlExpr (Insertion a) data NeedParens = Parens | Never @@ -317,7 +324,7 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where sub_selectDistinct = sub SELECT_DISTINCT EEntity ident ^. field = - ERaw Never $ \conn -> (useIdent conn ident <> ("." <> fieldName conn field), []) + ERaw Never $ \info -> (useIdent info ident <> ("." <> fieldName info field), []) EMaybe r ?. field = maybelize (r ^. field) where @@ -331,10 +338,10 @@ instance Esqueleto SqlQuery SqlExpr SqlBackend where nothing = unsafeSqlValue "NULL" joinV (ERaw p f) = ERaw p f countRows = unsafeSqlValue "COUNT(*)" - count (ERaw _ f) = ERaw Never $ \conn -> let (b, vals) = f conn + count (ERaw _ f) = ERaw Never $ \info -> let (b, vals) = f info in ("COUNT" <> parens b, vals) - not_ (ERaw p f) = ERaw Never $ \conn -> let (b, vals) = f conn + not_ (ERaw p f) = ERaw Never $ \info -> let (b, vals) = f info in ("NOT " <> parensM p b, vals) (==.) = unsafeSqlBinOp " = " @@ -399,21 +406,21 @@ instance ToSomeValues SqlExpr (SqlExpr (Value a)) where toSomeValues a = [SomeValue a] fieldName :: (PersistEntity val, PersistField typ) - => Connection -> EntityField val typ -> TLB.Builder -fieldName conn = fromDBName conn . fieldDB . persistFieldDef + => IdentInfo -> EntityField val typ -> TLB.Builder +fieldName info = fromDBName info . fieldDB . persistFieldDef setAux :: (PersistEntity val, PersistField typ) => EntityField val typ -> (SqlExpr (Entity val) -> SqlExpr (Value typ)) -> SqlExpr (Update val) setAux field mkVal = ESet $ \ent -> unsafeSqlBinOp " = " name (mkVal ent) - where name = ERaw Never $ \conn -> (fieldName conn field, mempty) + where name = ERaw Never $ \info -> (fieldName info field, mempty) sub :: PersistField a => Mode -> SqlQuery (SqlExpr (Value a)) -> SqlExpr (Value a) -sub mode query = ERaw Parens $ \conn -> toRawSql mode pureQuery conn query +sub mode query = ERaw Parens $ \info -> toRawSql mode pureQuery info query -fromDBName :: Connection -> DBName -> TLB.Builder -fromDBName conn = TLB.fromText . connEscapeName conn +fromDBName :: IdentInfo -> DBName -> TLB.Builder +fromDBName (conn, _) = TLB.fromText . connEscapeName conn existsHelper :: SqlQuery () -> SqlExpr (Value Bool) existsHelper = sub SELECT . (>> return true) @@ -444,8 +451,8 @@ ifNotEmptyList (EList _) _ x = x unsafeSqlBinOp :: TLB.Builder -> SqlExpr (Value a) -> SqlExpr (Value b) -> SqlExpr (Value c) unsafeSqlBinOp op (ERaw p1 f1) (ERaw p2 f2) = ERaw Parens f where - f conn = let (b1, vals1) = f1 conn - (b2, vals2) = f2 conn + f info = let (b1, vals1) = f1 info + (b2, vals2) = f2 info in ( parensM p1 b1 <> op <> parensM p2 b2 , vals1 <> vals2 ) {-# INLINE unsafeSqlBinOp #-} @@ -463,9 +470,9 @@ unsafeSqlValue v = ERaw Never $ \_ -> (v, mempty) unsafeSqlFunction :: UnsafeSqlFunctionArgument a => TLB.Builder -> a -> SqlExpr (Value b) unsafeSqlFunction name arg = - ERaw Never $ \conn -> + ERaw Never $ \info -> let (argsTLB, argsVals) = - uncommas' $ map (\(ERaw _ f) -> f conn) $ toArgList arg + uncommas' $ map (\(ERaw _ f) -> f info) $ toArgList arg in (name <> parens argsTLB, argsVals) class UnsafeSqlFunctionArgument a where @@ -527,7 +534,7 @@ rawSelectSource mode query = src run conn = uncurry rawQuery $ first builderToText $ - toRawSql mode pureQuery conn query + toRawSql mode pureQuery (conn, initialIdentState) query massage = do mrow <- C.await @@ -639,7 +646,7 @@ rawEsqueleto mode query = do conn <- SqlPersistT R.ask uncurry rawExecuteCount $ first builderToText $ - toRawSql mode pureQuery conn query + toRawSql mode pureQuery (conn, initialIdentState) query -- | Execute an @esqueleto@ @DELETE@ query inside @persistent@'s @@ -723,24 +730,37 @@ builderToText = TL.toStrict . TLB.toLazyTextWith defaultChunkSize -- @esqueleto@, instead of manually using this function (which is -- possible but tedious), you may just turn on query logging of -- @persistent@. -toRawSql :: SqlSelect a r => Mode -> QueryType a -> Connection -> SqlQuery a -> (TLB.Builder, [PersistValue]) -toRawSql mode qt conn query = - let (ret, SideData fromClauses setClauses whereClauses groupByClause havingClause orderByClauses limitClause) = - flip S.evalState initialIdentState $ +toRawSql :: SqlSelect a r => Mode -> QueryType a -> IdentInfo -> SqlQuery a -> (TLB.Builder, [PersistValue]) +toRawSql mode qt (conn, firstIdentState) query = + let ((ret, sd), finalIdentState) = + flip S.runState firstIdentState $ W.runWriterT $ unQ query + SideData fromClauses + setClauses + whereClauses + groupByClause + havingClause + orderByClauses + limitClause = sd + -- Pass the finalIdentState (containing all identifiers + -- that were used) to the subsequent calls. This ensures + -- that no name clashes will occur on subqueries that may + -- appear on the expressions below. + info = (conn, finalIdentState) in mconcat [ makeInsert qt ret - , makeSelect conn mode ret - , makeFrom conn mode fromClauses - , makeSet conn setClauses - , makeWhere conn whereClauses - , makeGroupBy conn groupByClause - , makeHaving conn havingClause - , makeOrderBy conn orderByClauses - , makeLimit conn limitClause + , makeSelect info mode ret + , makeFrom info mode fromClauses + , makeSet info setClauses + , makeWhere info whereClauses + , makeGroupBy info groupByClause + , makeHaving info havingClause + , makeOrderBy info orderByClauses + , makeLimit info limitClause ] + -- | (Internal) Mode of query being converted by 'toRawSql'. data Mode = SELECT | SELECT_DISTINCT | DELETE | UPDATE @@ -767,21 +787,21 @@ uncommas' :: Monoid a => [(TLB.Builder, a)] -> (TLB.Builder, a) uncommas' = (uncommas *** mconcat) . unzip -makeSelect :: SqlSelect a r => Connection -> Mode -> a -> (TLB.Builder, [PersistValue]) -makeSelect conn mode ret = +makeSelect :: SqlSelect a r => IdentInfo -> Mode -> a -> (TLB.Builder, [PersistValue]) +makeSelect info mode ret = case mode of SELECT -> withCols "SELECT " SELECT_DISTINCT -> withCols "SELECT DISTINCT " DELETE -> plain "DELETE " UPDATE -> plain "UPDATE " where - withCols v = first (v <>) (sqlSelectCols conn ret) + withCols v = first (v <>) (sqlSelectCols info ret) plain v = (v, []) -makeFrom :: Connection -> Mode -> [FromClause] -> (TLB.Builder, [PersistValue]) +makeFrom :: IdentInfo -> Mode -> [FromClause] -> (TLB.Builder, [PersistValue]) makeFrom _ _ [] = mempty -makeFrom conn mode fs = ret +makeFrom info mode fs = ret where ret = case collectOnClauses fs of Left expr -> throw $ mkExc expr @@ -802,8 +822,8 @@ makeFrom conn mode fs = ret base ident@(I identText) def = let db@(DBName dbText) = entityDB def in ( if dbText == identText - then fromDBName conn db - else fromDBName conn db <> (" AS " <> useIdent conn ident) + then fromDBName info db + else fromDBName info db <> (" AS " <> useIdent info ident) , mempty ) fromKind InnerJoinKind = " INNER JOIN " @@ -812,56 +832,56 @@ makeFrom conn mode fs = ret fromKind RightOuterJoinKind = " RIGHT OUTER JOIN " fromKind FullOuterJoinKind = " FULL OUTER JOIN " - makeOnClause (ERaw _ f) = first (" ON " <>) (f conn) + makeOnClause (ERaw _ f) = first (" ON " <>) (f info) mkExc :: SqlExpr (Value Bool) -> OnClauseWithoutMatchingJoinException mkExc (ERaw _ f) = OnClauseWithoutMatchingJoinException $ - TL.unpack $ TLB.toLazyText $ fst (f conn) + TL.unpack $ TLB.toLazyText $ fst (f info) -makeSet :: Connection -> [SetClause] -> (TLB.Builder, [PersistValue]) +makeSet :: IdentInfo -> [SetClause] -> (TLB.Builder, [PersistValue]) makeSet _ [] = mempty -makeSet conn os = first ("\nSET " <>) $ uncommas' (map mk os) +makeSet info os = first ("\nSET " <>) $ uncommas' (map mk os) where - mk (SetClause (ERaw _ f)) = f conn + mk (SetClause (ERaw _ f)) = f info -makeWhere :: Connection -> WhereClause -> (TLB.Builder, [PersistValue]) +makeWhere :: IdentInfo -> WhereClause -> (TLB.Builder, [PersistValue]) makeWhere _ NoWhere = mempty -makeWhere conn (Where (ERaw _ f)) = first ("\nWHERE " <>) (f conn) +makeWhere info (Where (ERaw _ f)) = first ("\nWHERE " <>) (f info) -makeGroupBy :: Connection -> GroupByClause -> (TLB.Builder, [PersistValue]) +makeGroupBy :: IdentInfo -> GroupByClause -> (TLB.Builder, [PersistValue]) makeGroupBy _ (GroupBy []) = (mempty, []) -makeGroupBy conn (GroupBy fields) = first ("\nGROUP BY " <>) build +makeGroupBy info (GroupBy fields) = first ("\nGROUP BY " <>) build where - build = uncommas' $ map (\(SomeValue (ERaw _ f)) -> f conn) fields + build = uncommas' $ map (\(SomeValue (ERaw _ f)) -> f info) fields -makeHaving :: Connection -> WhereClause -> (TLB.Builder, [PersistValue]) +makeHaving :: IdentInfo -> WhereClause -> (TLB.Builder, [PersistValue]) makeHaving _ NoWhere = mempty -makeHaving conn (Where (ERaw _ f)) = first ("\nHAVING " <>) (f conn) +makeHaving info (Where (ERaw _ f)) = first ("\nHAVING " <>) (f info) -makeOrderBy :: Connection -> [OrderByClause] -> (TLB.Builder, [PersistValue]) +makeOrderBy :: IdentInfo -> [OrderByClause] -> (TLB.Builder, [PersistValue]) makeOrderBy _ [] = mempty -makeOrderBy conn os = first ("\nORDER BY " <>) $ uncommas' (map mk os) +makeOrderBy info os = first ("\nORDER BY " <>) $ uncommas' (map mk os) where - mk (EOrderBy t (ERaw p f)) = first ((<> orderByType t) . parensM p) (f conn) + mk (EOrderBy t (ERaw p f)) = first ((<> orderByType t) . parensM p) (f info) orderByType ASC = " ASC" orderByType DESC = " DESC" -makeLimit :: Connection -> LimitClause -> (TLB.Builder, [PersistValue]) +makeLimit :: IdentInfo -> LimitClause -> (TLB.Builder, [PersistValue]) makeLimit _ (Limit Nothing Nothing) = mempty makeLimit _ (Limit Nothing (Just 0)) = mempty -makeLimit conn (Limit ml mo) = (ret, mempty) +makeLimit info (Limit ml mo) = (ret, mempty) where ret = TLB.singleton '\n' <> (limitTLB <> offsetTLB) limitTLB = case ml of Just l -> "LIMIT " <> TLBI.decimal l - Nothing -> TLB.fromText (connNoLimit conn) + Nothing -> TLB.fromText (connNoLimit $ fst info) offsetTLB = case mo of @@ -886,7 +906,7 @@ class SqlSelect a r | a -> r, r -> a where -- | Creates the variable part of the @SELECT@ query and -- returns the list of 'PersistValue's that will be given to -- 'rawQuery'. - sqlSelectCols :: Connection -> a -> (TLB.Builder, [PersistValue]) + sqlSelectCols :: IdentInfo -> a -> (TLB.Builder, [PersistValue]) -- | Number of columns that will be consumed. sqlSelectColCount :: Proxy a -> Int @@ -897,7 +917,7 @@ class SqlSelect a r | a -> r, r -> a where -- | You may return an insertion of some PersistEntity instance PersistEntity a => SqlSelect (SqlExpr (Insertion a)) (Insertion a) where - sqlSelectCols conn (EInsert _ f) = f conn + sqlSelectCols info (EInsert _ f) = f info sqlSelectColCount = const 0 sqlSelectProcessRow = const (Right (error msg)) where @@ -913,10 +933,10 @@ instance SqlSelect () () where -- | You may return an 'Entity' from a 'select' query. instance PersistEntity a => SqlSelect (SqlExpr (Entity a)) (Entity a) where - sqlSelectCols conn expr@(EEntity ident) = ret + sqlSelectCols info expr@(EEntity ident) = ret where process ed = uncommas $ - map ((name <>) . fromDBName conn) $ + map ((name <>) . fromDBName info) $ (entityID ed:) $ map fieldDB $ entityFields ed @@ -926,7 +946,7 @@ instance PersistEntity a => SqlSelect (SqlExpr (Entity a)) (Entity a) where -- clause), while 'rawSql' assumes that it's just the -- name of the table (which doesn't allow self-joins, for -- example). - name = useIdent conn ident <> "." + name = useIdent info ident <> "." ret = let ed = entityDef $ getEntityVal $ return expr in (process ed, mempty) sqlSelectColCount = (+1) . length . entityFields . entityDef . getEntityVal @@ -941,7 +961,7 @@ getEntityVal = const Proxy -- | You may return a possibly-@NULL@ 'Entity' from a 'select' query. instance PersistEntity a => SqlSelect (SqlExpr (Maybe (Entity a))) (Maybe (Entity a)) where - sqlSelectCols conn (EMaybe ent) = sqlSelectCols conn ent + sqlSelectCols info (EMaybe ent) = sqlSelectCols info ent sqlSelectColCount = sqlSelectColCount . fromEMaybe where fromEMaybe :: Proxy (SqlExpr (Maybe e)) -> Proxy (SqlExpr e) @@ -954,8 +974,8 @@ instance PersistEntity a => SqlSelect (SqlExpr (Maybe (Entity a))) (Maybe (Entit -- | You may return any single value (i.e. a single column) from -- a 'select' query. instance PersistField a => SqlSelect (SqlExpr (Value a)) (Value a) where - sqlSelectCols esc (ERaw p f) = let (b, vals) = f esc - in (parensM p b, vals) + sqlSelectCols info (ERaw p f) = let (b, vals) = f info + in (parensM p b, vals) sqlSelectColCount = const 1 sqlSelectProcessRow [pv] = Value <$> fromPersistValue pv sqlSelectProcessRow _ = Left "SqlSelect (Value a): wrong number of columns." @@ -1468,4 +1488,4 @@ insertGeneralSelect :: (MonadLogger m, MonadResourceBase m, SqlSelect (SqlExpr ( Mode -> SqlQuery (SqlExpr (Insertion a)) -> SqlPersistT m () insertGeneralSelect mode query = do conn <- SqlPersistT R.ask - uncurry rawExecute $ first builderToText $ toRawSql mode insertQuery conn query + uncurry rawExecute $ first builderToText $ toRawSql mode insertQuery (conn, initialIdentState) query