summary refs log tree commit diff stats
diff options
context:
space:
mode:
author Amjith Ramanujam <amjith@newrelic.com> 2014-12-31 23:17:50 -0800
committer Amjith Ramanujam <amjith@newrelic.com> 2014-12-31 23:17:50 -0800
commit 1cfdb4c55a4bebfe6536969acc2936131e7395c6 (patch)
tree ea728300eb0ac2986271a7fcc4055303a50c01e6
parent 01792765fa3d1d2803ca4d0dfed48ed50f38adfc (diff)
Extract tables from JOIN statements.
-rw-r--r-- TODO 4
-rw-r--r-- pgcli/packages/parseutils.py 17
-rw-r--r-- tests/test_parseutils.py 16
3 files changed, 31 insertions, 6 deletions
diff --git a/TODO b/TODO
index a1b3964f..af678173 100644
--- a/TODO
+++ b/TODO
@@ -2,9 +2,9 @@
* [o] Add JOIN to the list of keywords and provide proper autocompletion for it.
* [ ] Add a page to keep track of changelog in pgcli.com
* [ ] Refactor to sqlcompletion to consume the text from left to right and use a state machine to suggest cols or tables instead of relying on hacks.
-* [ ] Extract tables should also look for table names after the JOIN keyword.
+* [X] Extract tables should also look for table names after the JOIN keyword.
- SELECT * FROM some_very_long_table_name s JOIN another_fairly_long_name a ON s.id = a.num;
-* [ ] Test if the aliases are identified correctly if the AS keyword is used
+* [X] Test if the aliases are identified correctly if the AS keyword is used
- SELECT * FROM my_table AS m WHERE m.a > 5;
* [ ] ON keyword should suggest aliases. This is something we don't currently support since a collection of aliases is not maintained.
* [ ] Add a page to keep track of changelog in pgcli.com
diff --git a/pgcli/packages/parseutils.py b/pgcli/packages/parseutils.py
index b3086a63..13ac5fc4 100644
--- a/pgcli/packages/parseutils.py
+++ b/pgcli/packages/parseutils.py
@@ -74,19 +74,21 @@ def extract_from_part(parsed, stop_at_punctuation=True):
if is_subselect(item):
for x in extract_from_part(item, stop_at_punctuation):
yield x
+ elif stop_at_punctuation and item.ttype is Punctuation:
+ raise StopIteration
# An incomplete nested select won't be recognized correctly as a
# sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes
# the second FROM to trigger this elif condition resulting in a
# StopIteration. So we need to ignore the keyword if the keyword is
# FROM.
- elif stop_at_punctuation and item.ttype is Punctuation:
- raise StopIteration
- elif item.ttype is Keyword and item.value.upper() != 'FROM':
+ # Also 'SELECT * FROM abc JOIN def' will trigger this elif
+ # condition. So we need to ignore the keyword JOIN.
+ elif item.ttype is Keyword and item.value.upper() not in ('FROM', 'JOIN'):
raise StopIteration
else:
yield item
elif ((item.ttype is Keyword or item.ttype is Keyword.DML) and
- item.value.upper() in ('FROM', 'INTO', 'UPDATE', 'TABLE', )):
+ item.value.upper() in ('FROM', 'INTO', 'UPDATE', 'TABLE', 'JOIN',)):
tbl_prefix_seen = True
# 'SELECT a, FROM abc' will detect FROM as part of the column list.
# So this check here is necessary.
@@ -118,6 +120,13 @@ def extract_table_identifiers(token_stream):
# extract_tables is inspired from examples in the sqlparse lib.
def extract_tables(sql, include_alias=False):
+ """Extract the table names from an SQL statement.
+
+ Returns a list of table names if include_alias=False (default).
+ If include_alias=True, then a dictionary is returned where the keys are
+ aliases and values are real table names.
+
+ """
parsed = sqlparse.parse(sql)
if not parsed:
return []
diff --git a/tests/test_parseutils.py b/tests/test_parseutils.py
index c195a876..ae6e8f5a 100644
--- a/tests/test_parseutils.py
+++ b/tests/test_parseutils.py
@@ -1,5 +1,10 @@
from pgcli.packages.parseutils import extract_tables
+
+def test_empty_string():
+ tables = extract_tables('')
+ assert tables == []
+
def test_simple_select_single_table():
tables = extract_tables('select * from abc')
assert tables == ['abc']
@@ -31,3 +36,14 @@ def test_simple_insert_single_table():
def test_simple_update_table():
tables = extract_tables('update abc set id = 1')
assert tables == ['abc']
+
+def test_join_table():
+ tables = extract_tables('SELECT * FROM abc a JOIN def d ON s.id = a.num')
+ assert tables == ['abc', 'def']
+
+def test_join_as_table():
+ expected = {'m': 'my_table'}
+ assert extract_tables(
+ 'SELECT * FROM my_table AS m WHERE m.a > 5') == expected.values()
+ assert extract_tables(
+ 'SELECT * FROM my_table AS m WHERE m.a > 5', True) == expected