Claude Sonnet 3.5 - Fill-in: sqlparse
Failed to run pytest for the test suite.
Pytest collection failure.
Patch diff
diff --git a/sqlparse/cli.py b/sqlparse/cli.py
index 51e62e6..193c369 100755
--- a/sqlparse/cli.py
+++ b/sqlparse/cli.py
@@ -19,4 +19,5 @@ from sqlparse.exceptions import SQLParseError
def _error(msg):
"""Print msg and optionally exit with return code exit_."""
- pass
+ sys.stderr.write(f"Error: {msg}\n")
+ sys.exit(1)
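Note: the docstring says exiting is optional, but this fill-in always calls sys.exit(1). A minimal probe of the difference, assuming upstream main() reports an unreadable file through _error() and otherwise uses its return value as the exit status (the file name below is just an example):

    from sqlparse import cli

    try:
        rc = cli.main(["no-such-file.sql"])   # unreadable input is reported via _error()
        print("main returned", rc)
    except SystemExit as exc:
        # With the fill-in above, _error() exits instead of returning, so we land here.
        print("main exited with", exc.code)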
diff --git a/sqlparse/engine/grouping.py b/sqlparse/engine/grouping.py
index a730974..6927157 100644
--- a/sqlparse/engine/grouping.py
+++ b/sqlparse/engine/grouping.py
@@ -8,16 +8,58 @@ T_NAME = T.Name, T.Name.Placeholder
def _group_matching(tlist, cls):
"""Groups Tokens that have beginning and end."""
- pass
+ idx = 0
+ while idx < len(tlist):
+ token = tlist[idx]
+ if token.is_group() and not isinstance(token, cls):
+ _group_matching(token, cls)
+ idx += 1
+ elif cls.M_OPEN.match(token):
+ end = _find_matching(tlist[idx:], cls)
+ if end is None:
+ idx += 1
+ continue
+ group = cls(tlist[idx:end + idx + 1])
+ tlist.tokens[idx:end + idx + 1] = [group]
+ else:
+ idx += 1
@recurse(sql.Identifier)
def group_order(tlist):
"""Group together Identifier and Asc/Desc token"""
- pass
+ idx = 0
+ while idx < len(tlist):
+ token = tlist[idx]
+ if isinstance(token, sql.Identifier):
+ if idx + 1 < len(tlist) and tlist[idx + 1].match(T.Keyword, ('ASC', 'DESC')):
+ order = sql.Identifier([token, tlist[idx + 1]])
+ tlist.tokens[idx:idx + 2] = [order]
+ idx += 1
+ else:
+ idx += 1
def _group(tlist, cls, match, valid_prev=lambda t: True, valid_next=lambda
t: True, post=None, extend=True, recurse=True):
"""Groups together tokens that are joined by a middle token. i.e. x < y"""
- pass
+ idx = 0
+ while idx < len(tlist):
+ token = tlist[idx]
+ if token.is_group() and recurse:
+ _group(token, cls, match, valid_prev, valid_next, post, extend)
+ idx += 1
+ continue
+
+ if match(token) and idx > 0 and idx + 1 < len(tlist):
+ prev = tlist[idx - 1]
+ next_ = tlist[idx + 1]
+ if valid_prev(prev) and valid_next(next_):
+ group = cls([prev, token, next_])
+ if post:
+ post(group)
+ tlist.tokens[idx - 1:idx + 2] = [group]
+ if extend:
+ idx -= 1
+ continue
+ idx += 1
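The grouping pass is easiest to sanity-check through the public API; a minimal probe that parses nested parentheses and prints the token tree (whether nested sql.Parenthesis groups appear depends on the _group_matching fill-in above):

    import sqlparse

    stmt = sqlparse.parse("select * from (select a from (select b from t))")[0]
    stmt._pprint_tree()   # uses the TokenList._pprint_tree fill-in further down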
diff --git a/sqlparse/engine/statement_splitter.py b/sqlparse/engine/statement_splitter.py
index c9a1569..19406f6 100644
--- a/sqlparse/engine/statement_splitter.py
+++ b/sqlparse/engine/statement_splitter.py
@@ -9,12 +9,55 @@ class StatementSplitter:
def _reset(self):
"""Set the filter attributes to its default values"""
- pass
+ self.level = 0
+ self.stmt = sql.Statement()
+ self.consume_ws = False
+ self.consume_comments = False
def _change_splitlevel(self, ttype, value):
"""Get the new split level (increase, decrease or remain equal)"""
- pass
+ if ttype in T.Keyword:
+ if value.upper() in ('BEGIN', 'CASE'):
+ return 1
+ elif value.upper() == 'END':
+ return -1
+ elif ttype is T.Punctuation:
+ if value == '(':
+ return 1
+ elif value == ')':
+ return -1
+ return 0
def process(self, stream):
"""Process the stream"""
- pass
+ EOS_TTYPE = T.Whitespace, T.Comment.Single, T.Comment.Multiline
+
+ for ttype, value in stream:
+ # Consume whitespaces if needed
+ if self.consume_ws and ttype in T.Whitespace:
+ self.stmt.tokens.append(sql.Token(ttype, value))
+ continue
+
+ # Consume comments if needed
+ if self.consume_comments and ttype in T.Comment:
+ self.stmt.tokens.append(sql.Token(ttype, value))
+ continue
+
+ # Change the split level
+ self.level += self._change_splitlevel(ttype, value)
+
+ # Append the token to the statement
+ self.stmt.tokens.append(sql.Token(ttype, value))
+
+ if self.level <= 0 and ttype in EOS_TTYPE:
+ if self.stmt.tokens:
+ yield self.stmt
+ self._reset()
+ elif ttype is T.Punctuation and value == ';':
+ if self.stmt.tokens:
+ yield self.stmt
+ self._reset()
+
+ # Yield any remaining statement
+ if self.stmt.tokens:
+ yield self.stmt
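The splitter is reached through sqlparse.split(). Note that process() above yields whenever a whitespace or comment token arrives at level <= 0, so it is likely to split far more often than at semicolons; a quick probe:

    import sqlparse

    print(sqlparse.split("select 1; select 2;"))
    # expected if splitting works: ['select 1;', 'select 2;']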
diff --git a/sqlparse/filters/others.py b/sqlparse/filters/others.py
index a5dc327..6e9b635 100644
--- a/sqlparse/filters/others.py
+++ b/sqlparse/filters/others.py
@@ -4,20 +4,42 @@ from sqlparse.utils import split_unquoted_newlines
class StripCommentsFilter:
- pass
+ def process(self, stack, stream):
+ for token in stream:
+ if not token.is_whitespace and not token.ttype in T.Comment:
+ yield token
class StripWhitespaceFilter:
- pass
+ def process(self, stack, stream):
+ for token in stream:
+ if not token.is_whitespace:
+ yield token
class SpacesAroundOperatorsFilter:
- pass
+ def process(self, stack, stream):
+ for token in stream:
+ if token.ttype in T.Operator:
+ yield sql.Token(T.Whitespace, ' ')
+ yield token
+ yield sql.Token(T.Whitespace, ' ')
+ else:
+ yield token
class StripTrailingSemicolonFilter:
- pass
+ def process(self, stack, stream):
+ tokens = list(stream)
+ if tokens and tokens[-1].match(T.Punctuation, ';'):
+ tokens = tokens[:-1]
+ return tokens
class SerializerUnicode:
- pass
+ def process(self, stack, stream):
+ for token in stream:
+ value = token.value
+ if isinstance(value, bytes):
+ value = value.decode('utf-8')
+ yield sql.Token(token.ttype, value)
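These filters normally run inside sqlparse.format(); a small probe for comment stripping (the StripCommentsFilter fill-in above also appears to drop whitespace tokens, which is more than the option promises):

    import sqlparse

    text = "select 1 -- trailing comment\nfrom dual"
    print(sqlparse.format(text, strip_comments=True))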
diff --git a/sqlparse/filters/output.py b/sqlparse/filters/output.py
index d7e0078..37e6f48 100644
--- a/sqlparse/filters/output.py
+++ b/sqlparse/filters/output.py
@@ -10,8 +10,30 @@ class OutputFilter:
class OutputPythonFilter(OutputFilter):
- pass
+ def process(self, stream):
+ for token in stream:
+ if isinstance(token, sql.Statement):
+ yield sql.Token(T.Literal, f'{self.varname}{self.count} = ')
+ yield sql.Token(T.Literal, '"')
+ yield from token
+ yield sql.Token(T.Literal, '"')
+ yield sql.Token(T.Whitespace, '\n')
+ self.count += 1
+ else:
+ yield token
class OutputPHPFilter(OutputFilter):
varname_prefix = '$'
+
+ def process(self, stream):
+ for token in stream:
+ if isinstance(token, sql.Statement):
+ yield sql.Token(T.Literal, f'{self.varname}{self.count} = ')
+ yield sql.Token(T.Literal, '"')
+ yield from token
+ yield sql.Token(T.Literal, '";')
+ yield sql.Token(T.Whitespace, '\n')
+ self.count += 1
+ else:
+ yield token
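The output filters back the output_format option of sqlparse.format(); a minimal probe:

    import sqlparse

    print(sqlparse.format("select * from foo;", output_format="python"))
    # the documented upstream result is roughly: sql = 'select * from foo;'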
diff --git a/sqlparse/filters/reindent.py b/sqlparse/filters/reindent.py
index cccce71..ec48bb0 100644
--- a/sqlparse/filters/reindent.py
+++ b/sqlparse/filters/reindent.py
@@ -21,4 +21,7 @@ class ReindentFilter:
def _flatten_up_to_token(self, token):
"""Yields all tokens up to token but excluding current."""
- pass
+ for t in self._curr_stmt.flatten():
+ if t == token:
+ break
+ yield t
diff --git a/sqlparse/formatter.py b/sqlparse/formatter.py
index 71775a6..3bb9f2d 100644
--- a/sqlparse/formatter.py
+++ b/sqlparse/formatter.py
@@ -5,7 +5,18 @@ from sqlparse.exceptions import SQLParseError
def validate_options(options):
"""Validates options."""
- pass
+ if not isinstance(options, dict):
+ raise SQLParseError("Options must be a dictionary")
+
+ valid_options = {
+ 'keyword_case', 'identifier_case', 'strip_comments', 'reindent',
+ 'reindent_aligned', 'use_space_around_operators', 'strip_whitespace',
+ 'comma_first', 'wrap_after', 'right_margin'
+ }
+
+ for key in options:
+ if key not in valid_options:
+ raise SQLParseError(f"Invalid option: {key}")
def build_filter_stack(stack, options):
@@ -15,4 +26,50 @@ def build_filter_stack(stack, options):
stack: :class:`~sqlparse.filters.FilterStack` instance
options: Dictionary with options validated by validate_options.
"""
- pass
+ # Keyword case
+ if 'keyword_case' in options:
+ stack.preprocess.append(
+ filters.KeywordCaseFilter(case=options['keyword_case']))
+
+ # Identifier case
+ if 'identifier_case' in options:
+ stack.preprocess.append(
+ filters.IdentifierCaseFilter(case=options['identifier_case']))
+
+ # Strip comments
+ if options.get('strip_comments', False):
+ stack.preprocess.append(filters.StripCommentsFilter())
+
+ # Reindent
+ if options.get('reindent', False):
+ stack.enable_grouping()
+ stack.stmtprocess.append(
+ filters.ReindentFilter(
+ char=options.get('indent_char', ' '),
+ width=options.get('indent_width', 2),
+ wrap_after=options.get('wrap_after', 0),
+ comma_first=options.get('comma_first', False)
+ )
+ )
+
+ # Reindent aligned
+ if options.get('reindent_aligned', False):
+ stack.enable_grouping()
+ stack.stmtprocess.append(filters.AlignedIndentFilter())
+
+ # Use space around operators
+ if options.get('use_space_around_operators', False):
+ stack.enable_grouping()
+ stack.stmtprocess.append(filters.SpacesAroundOperatorsFilter())
+
+ # Strip whitespace
+ if options.get('strip_whitespace', False):
+ stack.postprocess.append(filters.StripWhitespaceFilter())
+
+ # Right margin
+ if 'right_margin' in options:
+ stack.enable_grouping()
+ stack.stmtprocess.append(
+ filters.RightMarginFilter(width=options['right_margin']))
+
+ return stack
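Two things stand out here: the allow-list omits documented options such as indent_width, indent_char and output_format, and validate_options() returns None rather than the options dict that upstream format() reassigns. A quick probe, assuming sqlparse.format() routes its keyword arguments through validate_options():

    import sqlparse
    from sqlparse.exceptions import SQLParseError

    try:
        print(sqlparse.format("select * from foo", reindent=True, indent_width=4))
    except SQLParseError as exc:
        print("rejected:", exc)   # indent_width is not in the allow-list above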
diff --git a/sqlparse/lexer.py b/sqlparse/lexer.py
index cc76039..794625a 100644
--- a/sqlparse/lexer.py
+++ b/sqlparse/lexer.py
@@ -16,28 +16,35 @@ class Lexer:
def get_default_instance(cls):
"""Returns the lexer instance used internally
by the sqlparse core functions."""
- pass
+ with cls._lock:
+ if cls._default_instance is None:
+ cls._default_instance = cls()
+ return cls._default_instance
def default_initialization(self):
"""Initialize the lexer with default dictionaries.
Useful if you need to revert custom syntax settings."""
- pass
+ self.clear()
+ self.add_keywords(keywords.KEYWORDS)
+ self.add_keywords(keywords.KEYWORDS_COMMON)
+ self.add_keywords(keywords.KEYWORDS_ORACLE)
def clear(self):
"""Clear all syntax configurations.
Useful if you want to load a reduced set of syntax configurations.
After this call, regexps and keyword dictionaries need to be loaded
to make the lexer functional again."""
- pass
+ self.keywords = {}
+ self.SQL_REGEX = []
def set_SQL_REGEX(self, SQL_REGEX):
"""Set the list of regex that will parse the SQL."""
- pass
+ self.SQL_REGEX = SQL_REGEX
def add_keywords(self, keywords):
"""Add keyword dictionaries. Keywords are looked up in the same order
that dictionaries were added."""
- pass
+ self.keywords.update(keywords)
def is_keyword(self, value):
"""Checks for a keyword.
@@ -45,7 +52,10 @@ class Lexer:
If the given value is in one of the KEYWORDS_* dictionary
it's considered a keyword. Otherwise, tokens.Name is returned.
"""
- pass
+ val = value.upper()
+ if val in self.keywords:
+ return self.keywords[val]
+ return tokens.Name
def get_tokens(self, text, encoding=None):
"""
@@ -60,7 +70,27 @@ class Lexer:
``stack`` is the initial stack (default: ``['root']``)
"""
- pass
+ if isinstance(text, TextIOBase):
+ text = text.read()
+
+ if encoding is not None:
+ if isinstance(text, str):
+ text = text.encode(encoding)
+ text = text.decode(encoding)
+
+ iterable = enumerate(text)
+ for pos, char in iterable:
+ for regex, token_type in self.SQL_REGEX:
+ match = regex.match(text, pos)
+ if match:
+ value = match.group()
+ if token_type is tokens.Name.Symbol:
+ token_type = self.is_keyword(value)
+ yield token_type, value
+ consume(iterable, len(value) - 1)
+ break
+ else:
+ yield tokens.Error, char
def tokenize(sql, encoding=None):
@@ -69,4 +99,5 @@ def tokenize(sql, encoding=None):
Tokenize *sql* using the :class:`Lexer` and return a 2-tuple stream
of ``(token type, value)`` items.
"""
- pass
+ lexer = Lexer.get_default_instance()
+ return lexer.get_tokens(sql, encoding)
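tokenize() can be exercised directly. Note that default_initialization() above loads keyword dictionaries but never calls set_SQL_REGEX(), so on this path the regex list stays empty and every character would fall through to tokens.Error; a minimal probe:

    from sqlparse import lexer

    for ttype, value in lexer.tokenize("select a from t"):
        print(ttype, repr(value))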
diff --git a/sqlparse/sql.py b/sqlparse/sql.py
index 44fef09..2fa139f 100644
--- a/sqlparse/sql.py
+++ b/sqlparse/sql.py
@@ -10,11 +10,11 @@ class NameAliasMixin:
def get_real_name(self):
"""Returns the real name (object name) of this identifier."""
- pass
+ return self.get_name()
def get_alias(self):
"""Returns the alias for this identifier or ``None``."""
- pass
+ return None
class Token:
@@ -50,7 +50,7 @@ class Token:
def flatten(self):
"""Resolve subgroups."""
- pass
+ yield self
def match(self, ttype, values, regex=False):
"""Checks whether the token matches the given arguments.
@@ -64,7 +64,17 @@ class Token:
If *regex* is ``True`` (default is ``False``) the given values are
treated as regular expressions.
"""
- pass
+ if self.ttype is not ttype:
+ return False
+
+ values = [values] if isinstance(values, str) else values
+
+ if regex:
+ return any(re.search(val, self.value, re.IGNORECASE if self.is_keyword else 0) for val in values)
+ elif self.is_keyword:
+ return self.normalized in [v.upper() for v in values]
+ else:
+ return self.value in values
def within(self, group_cls):
"""Returns ``True`` if this token is within *group_cls*.
@@ -72,15 +82,25 @@ class Token:
Use this method for example to check if an identifier is within
a function: ``t.within(sql.Function)``.
"""
- pass
+ parent = self.parent
+ while parent:
+ if isinstance(parent, group_cls):
+ return True
+ parent = parent.parent
+ return False
def is_child_of(self, other):
"""Returns ``True`` if this token is a direct child of *other*."""
- pass
+ return self.parent == other
def has_ancestor(self, other):
"""Returns ``True`` if *other* is in this tokens ancestry."""
- pass
+ parent = self.parent
+ while parent:
+ if parent == other:
+ return True
+ parent = parent.parent
+ return False
class TokenList(Token):
@@ -108,22 +128,58 @@ class TokenList(Token):
def _pprint_tree(self, max_depth=None, depth=0, f=None, _pre=''):
"""Pretty-print the object tree."""
- pass
+ if max_depth and depth > max_depth:
+ return
+
+ indent = ' ' * depth
+ if f is None:
+ f = sys.stdout
+
+ f.write(f'{_pre}{indent}{self.__class__.__name__} {self}\n')
+
+ for token in self.tokens:
+ if isinstance(token, TokenList):
+ token._pprint_tree(max_depth, depth + 1, f)
+ else:
+ f.write(f'{indent} {token.__class__.__name__} {token}\n')
def get_token_at_offset(self, offset):
"""Returns the token that is on position offset."""
- pass
+ idx = 0
+ for token in self.flatten():
+ end = idx + len(token.value)
+ if idx <= offset < end:
+ return token
+ idx = end
+ return None
def flatten(self):
"""Generator yielding ungrouped tokens.
This method is recursively called for all child tokens.
"""
- pass
+ for token in self.tokens:
+ if isinstance(token, TokenList):
+ yield from token.flatten()
+ else:
+ yield token
def _token_matching(self, funcs, start=0, end=None, reverse=False):
"""next token that match functions"""
- pass
+ funcs = [funcs] if not isinstance(funcs, (list, tuple)) else funcs
+ end = end or len(self.tokens)
+
+ if reverse:
+ indices = range(end - 1, start - 1, -1)
+ else:
+ indices = range(start, end)
+
+ for idx in indices:
+ token = self.tokens[idx]
+ for func in funcs:
+ if func(token):
+ return idx, token
+ return None, None
def token_first(self, skip_ws=True, skip_cm=False):
"""Returns the first child token.
diff --git a/sqlparse/utils.py b/sqlparse/utils.py
index a99ca61..b214be8 100644
--- a/sqlparse/utils.py
+++ b/sqlparse/utils.py
@@ -23,12 +23,18 @@ def split_unquoted_newlines(stmt):
Unlike str.splitlines(), this will ignore CR/LF/CR+LF if the requisite
character is inside of a string."""
- pass
+ if not stmt:
+ return []
+
+ parts = SPLIT_REGEX.split(stmt)
+ return [part for part in parts if LINE_MATCH.match(part) is None]
def remove_quotes(val):
"""Helper that removes surrounding quotes from strings."""
- pass
+ if val and len(val) > 1 and val[0] == val[-1] and val[0] in ("'", '"'):
+ return val[1:-1]
+ return val
def recurse(*cls):
@@ -37,7 +43,14 @@ def recurse(*cls):
:param cls: Classes to not recurse over
:return: function
"""
- pass
+ def wrap(f):
+ def wrapped(tlist):
+ for token in tlist.tokens:
+ if not isinstance(token, cls):
+ f(token)
+ return tlist
+ return wrapped
+ return wrap
def imt(token, i=None, m=None, t=None):
@@ -48,9 +61,27 @@ def imt(token, i=None, m=None, t=None):
:param t: TokenType or Tuple/List of TokenTypes
:return: bool
"""
- pass
+ if i is not None:
+ if isinstance(i, (list, tuple)):
+ return isinstance(token, tuple(i))
+ return isinstance(token, i)
+
+ if m is not None:
+ if isinstance(m[0], (list, tuple)):
+ return any(token.match(*_m) for _m in m)
+ return token.match(*m)
+
+ if t is not None:
+ if isinstance(t, (list, tuple)):
+ return token.ttype in t
+ return token.ttype is t
+
+ return False
def consume(iterator, n):
"""Advance the iterator n-steps ahead. If n is none, consume entirely."""
- pass
+ if n is None:
+ deque(iterator, maxlen=0)
+ else:
+ next(itertools.islice(iterator, n, n), None)
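The utils helpers are pure functions and easy to probe in isolation; consume() follows the standard itertools recipe:

    from sqlparse.utils import remove_quotes, consume

    print(remove_quotes('"foo"'))   # -> foo
    print(remove_quotes("bar"))     # -> bar (unquoted values pass through)

    it = iter(range(5))
    consume(it, 2)                  # advance two steps
    print(next(it))                 # -> 2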