summary | refs | log | tree | commit | diff | stats
diff options
context:
space:
mode:
author: Harel Ben-Attia <harelba@gmail.com> 2014-11-15 17:45:07 -0500
committer: Harel Ben-Attia <harelba@gmail.com> 2014-11-15 17:45:07 -0500
commit: faf39043dbe559d97428e193e13c38ec61c1b0a1 (patch)
tree: 4ac0a94286a3e1a757b51c2e8266682c25390933
parent: b0d28b175fb386eb858bd00d6ef4efadd9443e30 (diff)
Added reuse of loaded data + stdin injection to module API + tests
-rwxr-xr-x bin/q 143
-rwxr-xr-x test/test-suite 151
2 files changed, 242 insertions, 52 deletions
diff --git a/bin/q b/bin/q
index 62cb5c4..e579e55 100755
--- a/bin/q
+++ b/bin/q
@@ -201,6 +201,11 @@ class CannotUnzipStdInException(Exception):
def __init__(self):
pass
+class UnprovidedStdInException(Exception):
+
+ def __init__(self):
+ pass
+
class EmptyDataException(Exception):
def __init__(self):
@@ -579,7 +584,7 @@ def normalized_filename(filename):
class TableCreator(object):
- def __init__(self, db, filenames_str, line_splitter, skip_header=False, gzipped=False, encoding='UTF-8', mode='fluffy', expected_column_count=None, input_delimiter=None):
+ def __init__(self, db, filenames_str, line_splitter, skip_header=False, gzipped=False, encoding='UTF-8', mode='fluffy', expected_column_count=None, input_delimiter=None,stdin_file=None,stdin_filename='-'):
self.db = db
self.filenames_str = filenames_str
self.skip_header = skip_header
@@ -590,6 +595,9 @@ class TableCreator(object):
self.mode = mode
self.expected_column_count = expected_column_count
self.input_delimiter = input_delimiter
+ self.stdin_file = stdin_file
+ self.stdin_filename = stdin_filename
+
self.column_inferer = TableColumnInferer(
mode, expected_column_count, input_delimiter, skip_header)
@@ -613,8 +621,8 @@ class TableCreator(object):
# for each filename (or pattern)
for fileglob in filenames:
# Allow either stdin or a glob match
- if fileglob == '-':
- files_to_go_over = ['-']
+ if fileglob == self.stdin_filename:
+ files_to_go_over = [self.stdin_filename]
else:
files_to_go_over = glob.glob(fileglob)
@@ -629,8 +637,10 @@ class TableCreator(object):
self.lines_read = 0
# Check if it's standard input or a file
- if filename == '-':
- f = sys.stdin
+ if filename == self.stdin_filename:
+ if self.stdin_file is None:
+ raise UnprovidedStdInException()
+ f = self.stdin_file
if self.gzipped:
raise CannotUnzipStdInException()
else:
@@ -827,12 +837,26 @@ class QError(object):
self.msg = msg
self.errorcode = errorcode
+class QDataLoad(object):
+ def __init__(self,filename,start_time,end_time):
+ self.filename = filename
+ self.start_time = start_time
+ self.end_time = end_time
+
+ def duration(self):
+ return self.end_time - self.start_time
+
+ def __str__(self):
+ return "DataLoad<'%s' at %s (took %4.3f seconds)>" % (self.filename,self.start_time,self.duration())
+ __repr__ = __str__
+
class QOutput(object):
- def __init__(self,data,column_name_list,warnings,error):
+ def __init__(self,data,column_name_list,warnings,error,data_loads):
self.data = data
self.column_name_list = column_name_list
self.warnings = warnings
self.error = error
+ self.data_loads = data_loads
def error_occurred(self):
return self.error is not None
@@ -848,9 +872,10 @@ class QOutput(object):
else:
s.append("row_count=None")
if self.column_name_list is not None:
- s.append("column_names=%s" % ",".join(self.column_name_list))
+ s.append("column_names=[%s]" % ",".join(self.column_name_list))
else:
s.append("column_names=None")
+ s.append("data_load_count=%s" % len(self.data_loads))
return "QOutput<%s>" % ",".join(s)
__repr__ = __str__
@@ -869,8 +894,7 @@ class QTextAsData(object):
query_encoding=locale.getpreferredencoding(),
debug=False,
- analyze_only=False,
- auto_load = True
+ analyze_only=False
):
self.skip_header = skip_header
self.delimiter = delimiter
@@ -892,10 +916,11 @@ class QTextAsData(object):
self.dialect_name = 'q'
csv.register_dialect(self.dialect_name, **self.q_dialect)
- self.table_creators = None
+ self.table_creators = {}
+
+ # Create DB object
+ self.db = Sqlite3DB()
- if auto_load:
- self.load_data()
input_quoting_modes = { 'minimal' : csv.QUOTE_MINIMAL,
'all' : csv.QUOTE_ALL,
@@ -922,50 +947,66 @@ class QTextAsData(object):
return dialect
- def load_data(self):
- ## RLRL TODO separate load and execute phases
- pass
+ def ensure_data_is_loaded(self,query_str,stdin_file,stdin_filename='-'):
+ data_loads = []
+ # Create SQL statment
+ sql_object = Sql('%s' % query_str)
- def execute(self,query_str):
+ # Create a line splitter
+ line_splitter = LineSplitter(self.delimiter, self.expected_column_count)
+
+ # Get each "table name" which is actually the file name
+ for filename in sql_object.qtable_names:
+ start_time = time.time()
+
+ # reuse already loaded data, except for stdin file data (stdin file data will always
+ # be reloaded and overwritten)
+ if filename in self.table_creators.keys() and filename != stdin_filename:
+ continue
+
+ # Create the matching database table and populate it
+ table_creator = TableCreator(
+ self.db, filename, line_splitter, self.skip_header, self.gzipped_input, self.input_encoding,
+ mode=self.parsing_mode, expected_column_count=self.expected_column_count,
+ input_delimiter=self.delimiter,stdin_file = stdin_file,stdin_filename = stdin_filename)
+ table_creator.populate(self.dialect_name,self.analyze_only)
+ self.table_creators[filename] = table_creator
+
+ data_loads.append(QDataLoad(filename,start_time,time.time()))
+
+ if self.debug:
+ print >>sys.stderr, "TIMING - populate time is %4.3f" % (
+ time.time() - start_time)
+
+ return data_loads
+
+ def materialize_sql_object(self,sql_object):
+ for filename in sql_object.qtable_names:
+ sql_object.set_effective_table_name(filename,self.table_creators[filename].table_name)
+
+ def execute(self,query_str,stdin_file=None,stdin_filename='-'):
warnings = []
error = None
+ data_loads = []
db_results_obj = None
try:
- # Create DB object
- self.db = Sqlite3DB()
+ data_loads += self.ensure_data_is_loaded(query_str,stdin_file=stdin_file,stdin_filename=stdin_filename)
# Create SQL statment
- sql_object = Sql('%s' % query_str)
-
- # Create a line splitter
- line_splitter = LineSplitter(self.delimiter, self.expected_column_count)
-
- self.table_creators = []
- # Get each "table name" which is actually the file name
- for filename in sql_object.qtable_names:
- # Create the matching database table and populate it
- table_creator = TableCreator(self.db, filename, line_splitter, self.skip_header, self.gzipped_input, self.input_encoding,
- mode=self.parsing_mode, expected_column_count=self.expected_column_count, input_delimiter=self.delimiter)
- start_time = time.time()
- table_creator.populate(self.dialect_name,self.analyze_only)
- self.table_creators.append(table_creator)
- if self.debug:
- print >>sys.stderr, "TIMING - populate time is %4.3f" % (
- time.time() - start_time)
-
- # Replace the logical table name with the real table name
- sql_object.set_effective_table_name(filename, table_creator.table_name)
+ sql_object = Sql(query_str)
+
+ self.materialize_sql_object(sql_object)
if self.analyze_only:
- for table_creator in self.table_creators:
+ for filename,table_creator in self.table_creators.iteritems():
column_names = table_creator.column_inferer.get_column_names()
print "Table for file: %s" % normalized_filename(table_creator.filenames_str)
for k in column_names:
column_type = table_creator.column_inferer.get_column_dict()[k]
print " `%s` - %s" % (k, self.db.type_names[column_type].lower())
- return QOutput(data = None,column_name_list = None,warnings = warnings, error = None)
+ return QOutput(data = None,column_name_list = None,warnings = warnings, error = None, data_loads = [])
# Execute the query and fetch the data
db_results_obj = sql_object.execute_and_fetch(self.db)
@@ -987,34 +1028,36 @@ class QTextAsData(object):
error = QError(e,"Bad header row: %s" % e.msg,35)
except CannotUnzipStdInException,e:
error = QError(e,"Cannot decompress standard input. Pipe the input through zcat in order to decompress.",36)
+ except UnprovidedStdInException,e:
+ error = QError(e,"Standard Input must be provided in order to use it as a table",61)
except CouldNotConvertStringToNumericValueException,e:
error = QError(e,"Could not convert string to a numeric value. Did you use `-w nonnumeric` with unquoted string values? Error: %s" % e.msg,58)
except CouldNotParseInputException,e:
error = QError(e,"Could not parse the input. Please make sure to set the proper -w input-wrapping parameter for your input, and that you use the proper input encoding (-e). Error: %s" % e.msg,59)
except KeyboardInterrupt,e:
warnings.append(QWarning(e,"Interrupted"))
- except Error, e:
- error = QError(e,e.msg,199)
+ except Exception, e:
+ error = QError(e,str(e),199)
if db_results_obj is not None:
return QOutput(
data = db_results_obj.results,
column_name_list=db_results_obj.query_column_names,
warnings = warnings,
- error = error)
+ error = error,
+ data_loads = data_loads)
else:
- return QOutput(data = None, column_name_list = None,warnings = warnings,error = error)
+ return QOutput(data = None, column_name_list = None,warnings = warnings,error = error , data_loads = data_loads)
- def done(self):
- if self.table_creators is None:
- return
+ def unload(self):
- for table_creator in self.table_creators:
+ for filename,table_creator in self.table_creators.iteritems():
try:
table_creator.drop_table()
except:
# Support no-table select queries
pass
+ self.table_creators = {}
def analyze(self):
pass
@@ -1354,8 +1397,8 @@ def run_standalone():
debug=DEBUG,
analyze_only=options.analyze_only)
- q_output = q_query.execute(query_str)
- q_query.done()
+ q_output = q_query.execute(query_str,stdin_file=sys.stdin)
+ q_query.unload()
q_output_printer = QOutputPrinter(
delimiter=options.output_delimiter,
diff --git a/test/test-suite b/test/test-suite
index 373d766..39f82ce 100755
--- a/test/test-suite
+++ b/test/test-suite
@@ -18,6 +18,10 @@ import os
import time
from tempfile import NamedTemporaryFile
import locale
+import pprint
+
+sys.path.append(os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])),'..','bin'))
+from qtextasdata import QTextAsData,QOutput,QOutputPrinter
# q uses this encoding as the default output encoding. Some of the tests use it in order to
# make sure that the output is correctly encoded
@@ -1576,7 +1580,149 @@ class SqlTests(AbstractQTestCase):
self.assertEquals(len(o), 10*10*10)
self.cleanup(tmpfile2)
-
+
+class BasicModuleTests(AbstractQTestCase):
+
+ def test_simple_query(self):
+ tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6")
+
+ q = QTextAsData(skip_header=True,delimiter=' ')
+ r = q.execute('select * from %s' % tmpfile.name)
+
+ self.assertFalse(r.error_occurred())
+ self.assertEquals(len(r.warnings),0)
+ self.assertEquals(len(r.data),2)
+ self.assertEquals(r.column_name_list,['a','b','c'])
+ self.assertEquals(r.data,[(1,2,3),(4,5,6)])
+ self.assertEquals(len(r.data_loads),1)
+ self.assertEquals(r.data_loads[0].filename,tmpfile.name)
+
+ self.cleanup(tmpfile)
+
+ def test_loaded_data_reuse(self):
+ tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6")
+
+ q = QTextAsData(skip_header=True,delimiter=' ')
+ r1 = q.execute('select * from %s' % tmpfile.name)
+
+ r2 = q.execute('select * from %s' % tmpfile.name)
+
+ self.assertFalse(r1.error_occurred())
+ self.assertEquals(len(r1.warnings),0)
+ self.assertEquals(len(r1.data),2)
+ self.assertEquals(r1.column_name_list,['a','b','c'])
+ self.assertEquals(r1.data,[(1,2,3),(4,5,6)])
+ self.assertEquals(r1.data_loads[0].filename,tmpfile.name)
+
+ self.assertFalse(r2.error_occurred())
+ self.assertEquals(len(r1.data_loads),1)
+ self.assertEquals(r1.data_loads[0].filename,tmpfile.name)
+ self.assertEquals(len(r2.data_loads),0)
+ self.assertEquals(r2.data,r1.data)
+ self.assertEquals(r2.column_name_list,r2.column_name_list)
+ self.assertEquals(len(r2.warnings),0)
+
+ self.cleanup(tmpfile)
+
+ def test_stdin_injection(self):
+ tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6")
+
+ q = QTextAsData(skip_header=True,delimiter=' ')
+ r = q.execute('select * from -',stdin_file=file(tmpfile.name,'rb'))
+
+ self.assertFalse(r.error_occurred())
+ self.assertEquals(len(r.warnings),0)
+ self.assertEquals(len(r.data),2)
+ self.assertEquals(r.column_name_list,['a','b','c'])
+ self.assertEquals(r.data,[(1,2,3),(4,5,6)])
+ self.assertEquals(len(r.data_loads),1)
+ self.assertEquals(r.data_loads[0].filename,'-')
+
+ self.cleanup(tmpfile)
+
+ def test_named_stdin_injection(self):
+ tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6")
+
+ q = QTextAsData(skip_header=True,delimiter=' ')
+ r = q.execute('select a from my_stdin_data',stdin_file=file(tmpfile.name,'rb'),stdin_filename='my_stdin_data')
+
+ self.assertFalse(r.error_occurred())
+ self.assertEquals(len(r.warnings),0)
+ self.assertEquals(len(r.data),2)
+ self.assertEquals(r.column_name_list,['a'])
+ self.assertEquals(r.data,[(1,),(4,)])
+ self.assertEquals(len(r.data_loads),1)
+ self.assertEquals(r.data_loads[0].filename,'my_stdin_data')
+
+ self.cleanup(tmpfile)
+
+ def test_stdin_injection_isolation(self):
+ tmpfile1 = self.create_file_with_data("a b c\n1 2 3\n4 5 6")
+ tmpfile2 = self.create_file_with_data("d e f\n7 8 9\n10 11 12")
+
+ q = QTextAsData(skip_header=True,delimiter=' ')
+ r1 = q.execute('select * from -',stdin_file=file(tmpfile1.name,'rb'))
+
+ self.assertFalse(r1.error_occurred())
+ self.assertEquals(len(r1.warnings),0)
+ self.assertEquals(len(r1.data),2)
+ self.assertEquals(r1.column_name_list,['a','b','c'])
+ self.assertEquals(r1.data,[(1,2,3),(4,5,6)])
+ self.assertEquals(len(r1.data_loads),1)
+ self.assertEquals(r1.data_loads[0].filename,'-')
+
+ r2 = q.execute('select * from -',stdin_file=file(tmpfile2.name,'rb'))
+
+ self.assertFalse(r2.error_occurred())
+ self.assertEquals(len(r2.warnings),0)
+ self.assertEquals(len(r2.data),2)
+ self.assertEquals(r2.column_name_list,['d','e','f'])
+ self.assertEquals(r2.data,[(7,8,9),(10,11,12)])
+ # There should be another data load, even though it's the same 'filename' as before
+ self.assertEquals(len(r2.data_loads),1)
+ self.assertEquals(r2.data_loads[0].filename,'-')
+
+ self.cleanup(tmpfile1)
+ self.cleanup(tmpfile2)
+
+ def test_multiple_stdin_injection(self):
+ tmpfile1 = self.create_file_with_data("a b c\n1 2 3\n4 5 6")
+ tmpfile2 = self.create_file_with_data("d e f\n7 8 9\n10 11 12")
+
+ q = QTextAsData(skip_header=True,delimiter=' ')
+ r1 = q.execute('select * from my_stdin_data1',stdin_file=file(tmpfile1.name,'rb'),stdin_filename='my_stdin_data1')
+
+ self.assertFalse(r1.error_occurred())
+ self.assertEquals(len(r1.warnings),0)
+ self.assertEquals(len(r1.data),2)
+ self.assertEquals(r1.column_name_list,['a','b','c'])
+ self.assertEquals(r1.data,[(1,2,3),(4,5,6)])
+ self.assertEquals(len(r1.data_loads),1)
+ self.assertEquals(r1.data_loads[0].filename,'my_stdin_data1')
+
+ r2 = q.execute('select * from my_stdin_data2',stdin_file=file(tmpfile2.name,'rb'),stdin_filename='my_stdin_data2')
+
+ self.assertFalse(r2.error_occurred())
+ self.assertEquals(len(r2.warnings),0)
+ self.assertEquals(len(r2.data),2)
+ self.assertEquals(r2.column_name_list,['d','e','f'])
+ self.assertEquals(r2.data,[(7,8,9),(10,11,12)])
+ # There should be another data load, even though it's the same 'filename' as before
+ self.assertEquals(len(r2.data_loads),1)
+ self.assertEquals(r2.data_loads[0].filename,'my_stdin_data2')
+
+ r3 = q.execute('select aa.*,bb.* from my_stdin_data1 aa join my_stdin_data2 bb')
+
+ self.assertFalse(r3.error_occurred())
+ self.assertEquals(len(r3.warnings),0)
+ self.assertEquals(len(r3.data),4)
+ self.assertEquals(r3.column_name_list,['a','b','c','d','e','f'])
+ self.assertEquals(r3.data,[(1,2,3,7,8,9),(1,2,3,10,11,12),(4,5,6,7,8,9),(4,5,6,10,11,12)])
+ self.assertEquals(len(r3.data_loads),0)
+
+ self.cleanup(tmpfile1)
+ self.cleanup(tmpfile2)
+
def suite():
tl = unittest.TestLoader()
@@ -1584,7 +1730,8 @@ def suite():
parsing_mode = tl.loadTestsFromTestCase(ParsingModeTests)
sql = tl.loadTestsFromTestCase(SqlTests)
formatting = tl.loadTestsFromTestCase(FormattingTests)
- return unittest.TestSuite([basic_stuff, parsing_mode, sql, formatting])
+ basic_module_stuff = tl.loadTestsFromTestCase(BasicModuleTests)
+ return unittest.TestSuite([basic_module_stuff, basic_stuff, parsing_mode, sql, formatting])
if __name__ == '__main__':
unittest.TextTestRunner(verbosity=2).run(suite())