1
1
from __future__ import unicode_literals
2
- from jinja2 import contextfilter
3
2
from jinja2 import Environment
4
3
from jinja2 import Template
5
4
from jinja2 .ext import Extension
6
5
from jinja2 .lexer import Token
7
- from jinja2 .utils import Markup
6
+ from jinja2 .utils import markupsafe
8
7
from collections .abc import Iterable
9
8
10
9
try :
@@ -96,7 +95,7 @@ def filter_stream(self, stream):
96
95
def sql_safe (value ):
97
96
"""Filter to mark the value of an expression as safe for inserting
98
97
in a SQL statement"""
99
- return Markup (value )
98
+ return markupsafe . Markup (value )
100
99
101
100
def identifier (value ):
102
101
"""A filter that escapes a SQL identifier, usually database objects
@@ -117,7 +116,7 @@ def escape_postgres(tuple_or_str):
117
116
values = (tuple_or_str , ) if not isinstance (tuple_or_str , tuple ) else tuple_or_str
118
117
def escape_double_quotes (value ):
119
118
return value .replace ('"' , '""' )
120
- return Markup ('.' .join ('"{}"' .format (escape_double_quotes (value )) for value in values ))
119
+ return markupsafe . Markup ('.' .join ('"{}"' .format (escape_double_quotes (value )) for value in values ))
121
120
122
121
def bind (value , name ):
123
122
"""A filter that prints %s, and stores the value
@@ -126,7 +125,7 @@ def bind(value, name):
126
125
This filter is automatically applied to every {{variable}}
127
126
during the lexing stage, so developers can't forget to bind
128
127
"""
129
- if isinstance (value , Markup ):
128
+ if isinstance (value , markupsafe . Markup ):
130
129
return value
131
130
else :
132
131
return _bind_param (_thread_local .bind_params , name , value )
@@ -175,7 +174,7 @@ def identifier_filter(raw_identifier):
175
174
raw_identifier = (raw_identifier , )
176
175
if not isinstance (raw_identifier , Iterable ):
177
176
raise ValueError ("identifier filter expects a string or an Iterable" )
178
- return Markup ('.' .join (quote_and_escape (s ) for s in raw_identifier ))
177
+ return markupsafe . Markup ('.' .join (quote_and_escape (s ) for s in raw_identifier ))
179
178
180
179
return identifier_filter
181
180
@@ -195,12 +194,14 @@ class JinjaSql(object):
195
194
# asyncpg "where name = $1"
196
195
VALID_PARAM_STYLES = ('qmark' , 'numeric' , 'named' , 'format' , 'pyformat' , 'asyncpg' )
197
196
VALID_ID_QUOTE_CHARS = ('`' , '"' )
198
- def __init__ (self , env = None , param_style = 'format ' , db_engine = 'postgres' , identifier_quote_character = '"' ):
197
+ def __init__ (self , env = None , param_style = 'named ' , db_engine = 'postgres' , identifier_quote_character = '"' ):
199
198
# self.env = env or Environment()
200
199
# self._prepare_environment()
201
200
self .param_style = param_style
202
201
if identifier_quote_character not in self .VALID_ID_QUOTE_CHARS :
203
- raise ValueError ("identifier_quote_characters must be one of " + VALID_ID_QUOTE_CHARS )
202
+ raise ValueError (
203
+ f"identifier_quote_characters must be one of { JinjaSql .VALID_ID_QUOTE_CHARS } "
204
+ )
204
205
self .identifier_quote_character = identifier_quote_character
205
206
self .db_engine = db_engine
206
207
self .env = env or Environment ()
@@ -209,7 +210,6 @@ def __init__(self, env=None, param_style='format', db_engine='postgres', identif
209
210
def _prepare_environment (self ):
210
211
self .env .autoescape = True
211
212
self .env .add_extension (SqlExtension )
212
- self .env .add_extension ('jinja2.ext.autoescape' )
213
213
self .env .filters ["bind" ] = bind
214
214
self .env .filters ["sqlsafe" ] = sql_safe
215
215
self .env .filters ["inclause" ] = bind_in_clause
@@ -234,7 +234,7 @@ def _prepare_query(self, template, data):
234
234
if self .param_style in ('named' , 'pyformat' ):
235
235
bind_params = dict (bind_params )
236
236
elif self .param_style in ('qmark' , 'numeric' , 'format' , 'asyncpg' ):
237
- bind_params = list (bind_params .values ())
237
+ bind_params = tuple (bind_params .values ())
238
238
return query , bind_params
239
239
finally :
240
240
del _thread_local .bind_params
0 commit comments