diff options
author | Amjith Ramanujam <amjith@newrelic.com> | 2014-12-31 23:17:50 -0800 |
---|---|---|
committer | Amjith Ramanujam <amjith@newrelic.com> | 2014-12-31 23:17:50 -0800 |
commit | 1cfdb4c55a4bebfe6536969acc2936131e7395c6 (patch) | |
tree | ea728300eb0ac2986271a7fcc4055303a50c01e6 | |
parent | 01792765fa3d1d2803ca4d0dfed48ed50f38adfc (diff) |
Extract tables from JOIN statements.
-rw-r--r-- | TODO | 4 | ||||
-rw-r--r-- | pgcli/packages/parseutils.py | 17 | ||||
-rw-r--r-- | tests/test_parseutils.py | 16 |
3 files changed, 31 insertions, 6 deletions
@@ -2,9 +2,9 @@ * [o] Add JOIN to the list of keywords and provide proper autocompletion for it. * [ ] Add a page to keep track of changelog in pgcli.com * [ ] Refactor to sqlcompletion to consume the text from left to right and use a state machine to suggest cols or tables instead of relying on hacks. -* [ ] Extract tables should also look for table names after the JOIN keyword. +* [X] Extract tables should also look for table names after the JOIN keyword. - SELECT * FROM some_very_long_table_name s JOIN another_fairly_long_name a ON s.id = a.num; -* [ ] Test if the aliases are identified correctly if the AS keyword is used +* [X] Test if the aliases are identified correctly if the AS keyword is used - SELECT * FROM my_table AS m WHERE m.a > 5; * [ ] ON keyword should suggest aliases. This is something we don't currently support since a collection of aliases is not maintained. * [ ] Add a page to keep track of changelog in pgcli.com diff --git a/pgcli/packages/parseutils.py b/pgcli/packages/parseutils.py index b3086a63..13ac5fc4 100644 --- a/pgcli/packages/parseutils.py +++ b/pgcli/packages/parseutils.py @@ -74,19 +74,21 @@ def extract_from_part(parsed, stop_at_punctuation=True): if is_subselect(item): for x in extract_from_part(item, stop_at_punctuation): yield x + elif stop_at_punctuation and item.ttype is Punctuation: + raise StopIteration # An incomplete nested select won't be recognized correctly as a # sub-select. eg: 'SELECT * FROM (SELECT id FROM user'. This causes # the second FROM to trigger this elif condition resulting in a # StopIteration. So we need to ignore the keyword if the keyword # FROM. - elif stop_at_punctuation and item.ttype is Punctuation: - raise StopIteration - elif item.ttype is Keyword and item.value.upper() != 'FROM': + # Also 'SELECT * FROM abc JOIN def' will trigger this elif + # condition. So we need to ignore the keyword JOIN. + elif item.ttype is Keyword and item.value.upper() not in ('FROM', 'JOIN'): raise StopIteration else: yield item elif ((item.ttype is Keyword or item.ttype is Keyword.DML) and - item.value.upper() in ('FROM', 'INTO', 'UPDATE', 'TABLE', )): + item.value.upper() in ('FROM', 'INTO', 'UPDATE', 'TABLE', 'JOIN',)): tbl_prefix_seen = True # 'SELECT a, FROM abc' will detect FROM as part of the column list. # So this check here is necessary. @@ -118,6 +120,13 @@ def extract_table_identifiers(token_stream): # extract_tables is inspired from examples in the sqlparse lib. def extract_tables(sql, include_alias=False): + """Extract the table names from an SQL statment. + + Returns a list of table names if include_alias=False (default). + If include_alias=True, then a dictionary is returned where the keys are + aliases and values are real table names. + + """ parsed = sqlparse.parse(sql) if not parsed: return [] diff --git a/tests/test_parseutils.py b/tests/test_parseutils.py index c195a876..ae6e8f5a 100644 --- a/tests/test_parseutils.py +++ b/tests/test_parseutils.py @@ -1,5 +1,10 @@ from pgcli.packages.parseutils import extract_tables + +def test_empty_string(): + tables = extract_tables('') + assert tables == [] + def test_simple_select_single_table(): tables = extract_tables('select * from abc') assert tables == ['abc'] @@ -31,3 +36,14 @@ def test_simple_insert_single_table(): def test_simple_update_table(): tables = extract_tables('update abc set id = 1') assert tables == ['abc'] + +def test_join_table(): + tables = extract_tables('SELECT * FROM abc a JOIN def d ON s.id = a.num') + assert tables == ['abc', 'def'] + +def test_join_as_table(): + expected = {'m': 'my_table'} + assert extract_tables( + 'SELECT * FROM my_table AS m WHERE m.a > 5') == expected.values() + assert extract_tables( + 'SELECT * FROM my_table AS m WHERE m.a > 5', True) == expected |