diff options
author | Harel Ben-Attia <harelba@gmail.com> | 2014-11-15 17:45:07 -0500 |
---|---|---|
committer | Harel Ben-Attia <harelba@gmail.com> | 2014-11-15 17:45:07 -0500 |
commit | faf39043dbe559d97428e193e13c38ec61c1b0a1 (patch) | |
tree | 4ac0a94286a3e1a757b51c2e8266682c25390933 | |
parent | b0d28b175fb386eb858bd00d6ef4efadd9443e30 (diff) |
Added reuse of loaded data + stdin injection to module API + tests
-rwxr-xr-x | bin/q | 143 | ||||
-rwxr-xr-x | test/test-suite | 151 |
2 files changed, 242 insertions, 52 deletions
@@ -201,6 +201,11 @@ class CannotUnzipStdInException(Exception): def __init__(self): pass +class UnprovidedStdInException(Exception): + + def __init__(self): + pass + class EmptyDataException(Exception): def __init__(self): @@ -579,7 +584,7 @@ def normalized_filename(filename): class TableCreator(object): - def __init__(self, db, filenames_str, line_splitter, skip_header=False, gzipped=False, encoding='UTF-8', mode='fluffy', expected_column_count=None, input_delimiter=None): + def __init__(self, db, filenames_str, line_splitter, skip_header=False, gzipped=False, encoding='UTF-8', mode='fluffy', expected_column_count=None, input_delimiter=None,stdin_file=None,stdin_filename='-'): self.db = db self.filenames_str = filenames_str self.skip_header = skip_header @@ -590,6 +595,9 @@ class TableCreator(object): self.mode = mode self.expected_column_count = expected_column_count self.input_delimiter = input_delimiter + self.stdin_file = stdin_file + self.stdin_filename = stdin_filename + self.column_inferer = TableColumnInferer( mode, expected_column_count, input_delimiter, skip_header) @@ -613,8 +621,8 @@ class TableCreator(object): # for each filename (or pattern) for fileglob in filenames: # Allow either stdin or a glob match - if fileglob == '-': - files_to_go_over = ['-'] + if fileglob == self.stdin_filename: + files_to_go_over = [self.stdin_filename] else: files_to_go_over = glob.glob(fileglob) @@ -629,8 +637,10 @@ class TableCreator(object): self.lines_read = 0 # Check if it's standard input or a file - if filename == '-': - f = sys.stdin + if filename == self.stdin_filename: + if self.stdin_file is None: + raise UnprovidedStdInException() + f = self.stdin_file if self.gzipped: raise CannotUnzipStdInException() else: @@ -827,12 +837,26 @@ class QError(object): self.msg = msg self.errorcode = errorcode +class QDataLoad(object): + def __init__(self,filename,start_time,end_time): + self.filename = filename + self.start_time = start_time + self.end_time = end_time + + def duration(self): + return self.end_time - self.start_time + + def __str__(self): + return "DataLoad<'%s' at %s (took %4.3f seconds)>" % (self.filename,self.start_time,self.duration()) + __repr__ = __str__ + class QOutput(object): - def __init__(self,data,column_name_list,warnings,error): + def __init__(self,data,column_name_list,warnings,error,data_loads): self.data = data self.column_name_list = column_name_list self.warnings = warnings self.error = error + self.data_loads = data_loads def error_occurred(self): return self.error is not None @@ -848,9 +872,10 @@ class QOutput(object): else: s.append("row_count=None") if self.column_name_list is not None: - s.append("column_names=%s" % ",".join(self.column_name_list)) + s.append("column_names=[%s]" % ",".join(self.column_name_list)) else: s.append("column_names=None") + s.append("data_load_count=%s" % len(self.data_loads)) return "QOutput<%s>" % ",".join(s) __repr__ = __str__ @@ -869,8 +894,7 @@ class QTextAsData(object): query_encoding=locale.getpreferredencoding(), debug=False, - analyze_only=False, - auto_load = True + analyze_only=False ): self.skip_header = skip_header self.delimiter = delimiter @@ -892,10 +916,11 @@ class QTextAsData(object): self.dialect_name = 'q' csv.register_dialect(self.dialect_name, **self.q_dialect) - self.table_creators = None + self.table_creators = {} + + # Create DB object + self.db = Sqlite3DB() - if auto_load: - self.load_data() input_quoting_modes = { 'minimal' : csv.QUOTE_MINIMAL, 'all' : csv.QUOTE_ALL, @@ -922,50 +947,66 @@ class QTextAsData(object): return dialect - def load_data(self): - ## RLRL TODO separate load and execute phases - pass + def ensure_data_is_loaded(self,query_str,stdin_file,stdin_filename='-'): + data_loads = [] + # Create SQL statment + sql_object = Sql('%s' % query_str) - def execute(self,query_str): + # Create a line splitter + line_splitter = LineSplitter(self.delimiter, self.expected_column_count) + + # Get each "table name" which is actually the file name + for filename in sql_object.qtable_names: + start_time = time.time() + + # reuse already loaded data, except for stdin file data (stdin file data will always + # be reloaded and overwritten) + if filename in self.table_creators.keys() and filename != stdin_filename: + continue + + # Create the matching database table and populate it + table_creator = TableCreator( + self.db, filename, line_splitter, self.skip_header, self.gzipped_input, self.input_encoding, + mode=self.parsing_mode, expected_column_count=self.expected_column_count, + input_delimiter=self.delimiter,stdin_file = stdin_file,stdin_filename = stdin_filename) + table_creator.populate(self.dialect_name,self.analyze_only) + self.table_creators[filename] = table_creator + + data_loads.append(QDataLoad(filename,start_time,time.time())) + + if self.debug: + print >>sys.stderr, "TIMING - populate time is %4.3f" % ( + time.time() - start_time) + + return data_loads + + def materialize_sql_object(self,sql_object): + for filename in sql_object.qtable_names: + sql_object.set_effective_table_name(filename,self.table_creators[filename].table_name) + + def execute(self,query_str,stdin_file=None,stdin_filename='-'): warnings = [] error = None + data_loads = [] db_results_obj = None try: - # Create DB object - self.db = Sqlite3DB() + data_loads += self.ensure_data_is_loaded(query_str,stdin_file=stdin_file,stdin_filename=stdin_filename) # Create SQL statment - sql_object = Sql('%s' % query_str) - - # Create a line splitter - line_splitter = LineSplitter(self.delimiter, self.expected_column_count) - - self.table_creators = [] - # Get each "table name" which is actually the file name - for filename in sql_object.qtable_names: - # Create the matching database table and populate it - table_creator = TableCreator(self.db, filename, line_splitter, self.skip_header, self.gzipped_input, self.input_encoding, - mode=self.parsing_mode, expected_column_count=self.expected_column_count, input_delimiter=self.delimiter) - start_time = time.time() - table_creator.populate(self.dialect_name,self.analyze_only) - self.table_creators.append(table_creator) - if self.debug: - print >>sys.stderr, "TIMING - populate time is %4.3f" % ( - time.time() - start_time) - - # Replace the logical table name with the real table name - sql_object.set_effective_table_name(filename, table_creator.table_name) + sql_object = Sql(query_str) + + self.materialize_sql_object(sql_object) if self.analyze_only: - for table_creator in self.table_creators: + for filename,table_creator in self.table_creators.iteritems(): column_names = table_creator.column_inferer.get_column_names() print "Table for file: %s" % normalized_filename(table_creator.filenames_str) for k in column_names: column_type = table_creator.column_inferer.get_column_dict()[k] print " `%s` - %s" % (k, self.db.type_names[column_type].lower()) - return QOutput(data = None,column_name_list = None,warnings = warnings, error = None) + return QOutput(data = None,column_name_list = None,warnings = warnings, error = None, data_loads = []) # Execute the query and fetch the data db_results_obj = sql_object.execute_and_fetch(self.db) @@ -987,34 +1028,36 @@ class QTextAsData(object): error = QError(e,"Bad header row: %s" % e.msg,35) except CannotUnzipStdInException,e: error = QError(e,"Cannot decompress standard input. Pipe the input through zcat in order to decompress.",36) + except UnprovidedStdInException,e: + error = QError(e,"Standard Input must be provided in order to use it as a table",61) except CouldNotConvertStringToNumericValueException,e: error = QError(e,"Could not convert string to a numeric value. Did you use `-w nonnumeric` with unquoted string values? Error: %s" % e.msg,58) except CouldNotParseInputException,e: error = QError(e,"Could not parse the input. Please make sure to set the proper -w input-wrapping parameter for your input, and that you use the proper input encoding (-e). Error: %s" % e.msg,59) except KeyboardInterrupt,e: warnings.append(QWarning(e,"Interrupted")) - except Error, e: - error = QError(e,e.msg,199) + except Exception, e: + error = QError(e,str(e),199) if db_results_obj is not None: return QOutput( data = db_results_obj.results, column_name_list=db_results_obj.query_column_names, warnings = warnings, - error = error) + error = error, + data_loads = data_loads) else: - return QOutput(data = None, column_name_list = None,warnings = warnings,error = error) + return QOutput(data = None, column_name_list = None,warnings = warnings,error = error , data_loads = data_loads) - def done(self): - if self.table_creators is None: - return + def unload(self): - for table_creator in self.table_creators: + for filename,table_creator in self.table_creators.iteritems(): try: table_creator.drop_table() except: # Support no-table select queries pass + self.table_creators = {} def analyze(self): pass @@ -1354,8 +1397,8 @@ def run_standalone(): debug=DEBUG, analyze_only=options.analyze_only) - q_output = q_query.execute(query_str) - q_query.done() + q_output = q_query.execute(query_str,stdin_file=sys.stdin) + q_query.unload() q_output_printer = QOutputPrinter( delimiter=options.output_delimiter, diff --git a/test/test-suite b/test/test-suite index 373d766..39f82ce 100755 --- a/test/test-suite +++ b/test/test-suite @@ -18,6 +18,10 @@ import os import time from tempfile import NamedTemporaryFile import locale +import pprint + +sys.path.append(os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])),'..','bin')) +from qtextasdata import QTextAsData,QOutput,QOutputPrinter # q uses this encoding as the default output encoding. Some of the tests use it in order to # make sure that the output is correctly encoded @@ -1576,7 +1580,149 @@ class SqlTests(AbstractQTestCase): self.assertEquals(len(o), 10*10*10) self.cleanup(tmpfile2) - + +class BasicModuleTests(AbstractQTestCase): + + def test_simple_query(self): + tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6") + + q = QTextAsData(skip_header=True,delimiter=' ') + r = q.execute('select * from %s' % tmpfile.name) + + self.assertFalse(r.error_occurred()) + self.assertEquals(len(r.warnings),0) + self.assertEquals(len(r.data),2) + self.assertEquals(r.column_name_list,['a','b','c']) + self.assertEquals(r.data,[(1,2,3),(4,5,6)]) + self.assertEquals(len(r.data_loads),1) + self.assertEquals(r.data_loads[0].filename,tmpfile.name) + + self.cleanup(tmpfile) + + def test_loaded_data_reuse(self): + tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6") + + q = QTextAsData(skip_header=True,delimiter=' ') + r1 = q.execute('select * from %s' % tmpfile.name) + + r2 = q.execute('select * from %s' % tmpfile.name) + + self.assertFalse(r1.error_occurred()) + self.assertEquals(len(r1.warnings),0) + self.assertEquals(len(r1.data),2) + self.assertEquals(r1.column_name_list,['a','b','c']) + self.assertEquals(r1.data,[(1,2,3),(4,5,6)]) + self.assertEquals(r1.data_loads[0].filename,tmpfile.name) + + self.assertFalse(r2.error_occurred()) + self.assertEquals(len(r1.data_loads),1) + self.assertEquals(r1.data_loads[0].filename,tmpfile.name) + self.assertEquals(len(r2.data_loads),0) + self.assertEquals(r2.data,r1.data) + self.assertEquals(r2.column_name_list,r2.column_name_list) + self.assertEquals(len(r2.warnings),0) + + self.cleanup(tmpfile) + + def test_stdin_injection(self): + tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6") + + q = QTextAsData(skip_header=True,delimiter=' ') + r = q.execute('select * from -',stdin_file=file(tmpfile.name,'rb')) + + self.assertFalse(r.error_occurred()) + self.assertEquals(len(r.warnings),0) + self.assertEquals(len(r.data),2) + self.assertEquals(r.column_name_list,['a','b','c']) + self.assertEquals(r.data,[(1,2,3),(4,5,6)]) + self.assertEquals(len(r.data_loads),1) + self.assertEquals(r.data_loads[0].filename,'-') + + self.cleanup(tmpfile) + + def test_named_stdin_injection(self): + tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6") + + q = QTextAsData(skip_header=True,delimiter=' ') + r = q.execute('select a from my_stdin_data',stdin_file=file(tmpfile.name,'rb'),stdin_filename='my_stdin_data') + + self.assertFalse(r.error_occurred()) + self.assertEquals(len(r.warnings),0) + self.assertEquals(len(r.data),2) + self.assertEquals(r.column_name_list,['a']) + self.assertEquals(r.data,[(1,),(4,)]) + self.assertEquals(len(r.data_loads),1) + self.assertEquals(r.data_loads[0].filename,'my_stdin_data') + + self.cleanup(tmpfile) + + def test_stdin_injection_isolation(self): + tmpfile1 = self.create_file_with_data("a b c\n1 2 3\n4 5 6") + tmpfile2 = self.create_file_with_data("d e f\n7 8 9\n10 11 12") + + q = QTextAsData(skip_header=True,delimiter=' ') + r1 = q.execute('select * from -',stdin_file=file(tmpfile1.name,'rb')) + + self.assertFalse(r1.error_occurred()) + self.assertEquals(len(r1.warnings),0) + self.assertEquals(len(r1.data),2) + self.assertEquals(r1.column_name_list,['a','b','c']) + self.assertEquals(r1.data,[(1,2,3),(4,5,6)]) + self.assertEquals(len(r1.data_loads),1) + self.assertEquals(r1.data_loads[0].filename,'-') + + r2 = q.execute('select * from -',stdin_file=file(tmpfile2.name,'rb')) + + self.assertFalse(r2.error_occurred()) + self.assertEquals(len(r2.warnings),0) + self.assertEquals(len(r2.data),2) + self.assertEquals(r2.column_name_list,['d','e','f']) + self.assertEquals(r2.data,[(7,8,9),(10,11,12)]) + # There should be another data load, even though it's the same 'filename' as before + self.assertEquals(len(r2.data_loads),1) + self.assertEquals(r2.data_loads[0].filename,'-') + + self.cleanup(tmpfile1) + self.cleanup(tmpfile2) + + def test_multiple_stdin_injection(self): + tmpfile1 = self.create_file_with_data("a b c\n1 2 3\n4 5 6") + tmpfile2 = self.create_file_with_data("d e f\n7 8 9\n10 11 12") + + q = QTextAsData(skip_header=True,delimiter=' ') + r1 = q.execute('select * from my_stdin_data1',stdin_file=file(tmpfile1.name,'rb'),stdin_filename='my_stdin_data1') + + self.assertFalse(r1.error_occurred()) + self.assertEquals(len(r1.warnings),0) + self.assertEquals(len(r1.data),2) + self.assertEquals(r1.column_name_list,['a','b','c']) + self.assertEquals(r1.data,[(1,2,3),(4,5,6)]) + self.assertEquals(len(r1.data_loads),1) + self.assertEquals(r1.data_loads[0].filename,'my_stdin_data1') + + r2 = q.execute('select * from my_stdin_data2',stdin_file=file(tmpfile2.name,'rb'),stdin_filename='my_stdin_data2') + + self.assertFalse(r2.error_occurred()) + self.assertEquals(len(r2.warnings),0) + self.assertEquals(len(r2.data),2) + self.assertEquals(r2.column_name_list,['d','e','f']) + self.assertEquals(r2.data,[(7,8,9),(10,11,12)]) + # There should be another data load, even though it's the same 'filename' as before + self.assertEquals(len(r2.data_loads),1) + self.assertEquals(r2.data_loads[0].filename,'my_stdin_data2') + + r3 = q.execute('select aa.*,bb.* from my_stdin_data1 aa join my_stdin_data2 bb') + + self.assertFalse(r3.error_occurred()) + self.assertEquals(len(r3.warnings),0) + self.assertEquals(len(r3.data),4) + self.assertEquals(r3.column_name_list,['a','b','c','d','e','f']) + self.assertEquals(r3.data,[(1,2,3,7,8,9),(1,2,3,10,11,12),(4,5,6,7,8,9),(4,5,6,10,11,12)]) + self.assertEquals(len(r3.data_loads),0) + + self.cleanup(tmpfile1) + self.cleanup(tmpfile2) + def suite(): tl = unittest.TestLoader() @@ -1584,7 +1730,8 @@ def suite(): parsing_mode = tl.loadTestsFromTestCase(ParsingModeTests) sql = tl.loadTestsFromTestCase(SqlTests) formatting = tl.loadTestsFromTestCase(FormattingTests) - return unittest.TestSuite([basic_stuff, parsing_mode, sql, formatting]) + basic_module_stuff = tl.loadTestsFromTestCase(BasicModuleTests) + return unittest.TestSuite([basic_module_stuff, basic_stuff, parsing_mode, sql, formatting]) if __name__ == '__main__': unittest.TextTestRunner(verbosity=2).run(suite()) |