From 08cf2bb897ee7ce4ebac22b4ac088a8a456f086f Mon Sep 17 00:00:00 2001 From: Harel Ben-Attia Date: Sat, 15 Nov 2014 19:41:47 -0500 Subject: Allow different input params for each loaded file + loading file manually + fixed modeling of query encoding + tests --- bin/q | 216 +++++++++++++++++++++++++++++++++----------------------- test/test-suite | 102 +++++++++++++++++++++++--- 2 files changed, 222 insertions(+), 96 deletions(-) diff --git a/bin/q b/bin/q index e579e55..e7e5fd6 100755 --- a/bin/q +++ b/bin/q @@ -195,6 +195,14 @@ class BadHeaderException(Exception): def __str(self): return repr(self.msg) +class EncodedQueryException(Exception): + + def __init__(self, msg): + self.msg = msg + + def __str(self): + return repr(self.msg) + class CannotUnzipStdInException(Exception): @@ -836,6 +844,7 @@ class QError(object): self.exception = exception self.msg = msg self.errorcode = errorcode + self.traceback = traceback.format_exc() class QDataLoad(object): def __init__(self,filename,start_time,end_time): @@ -851,7 +860,7 @@ class QDataLoad(object): __repr__ = __str__ class QOutput(object): - def __init__(self,data,column_name_list,warnings,error,data_loads): + def __init__(self,data=None,column_name_list=None,warnings=[],error=None,data_loads=[]): self.data = data self.column_name_list = column_name_list self.warnings = warnings @@ -879,23 +888,12 @@ class QOutput(object): return "QOutput<%s>" % ",".join(s) __repr__ = __str__ -class QTextAsData(object): - def __init__(self, - skip_header=False, - delimiter=' ', - input_encoding='UTF-8', - gzipped_input=False, - parsing_mode='relaxed', - expected_column_count=None, - keep_leading_whitespace_in_values=False, - disable_double_double_quoting=False, - disable_escaped_double_quoting=False, - input_quoting_mode='minimal', - - query_encoding=locale.getpreferredencoding(), - debug=False, - analyze_only=False - ): +class QInputParams(object): + def __init__(self,skip_header=False, + delimiter=' ',input_encoding='UTF-8',gzipped_input=False,parsing_mode='relaxed', + expected_column_count=None,keep_leading_whitespace_in_values=False, + disable_double_double_quoting=False,disable_escaped_double_quoting=False, + input_quoting_mode='minimal',stdin_file=None,stdin_filename='-'): self.skip_header = skip_header self.delimiter = delimiter self.input_encoding = input_encoding @@ -907,14 +905,25 @@ class QTextAsData(object): self.disable_escaped_double_quoting = disable_escaped_double_quoting self.input_quoting_mode = input_quoting_mode - self.query_encoding = query_encoding - self.debug = debug - self.analyze_only = analyze_only + def merged_with(self,input_params): + params = QInputParams(**self.__dict__) + if input_params is not None: + params.__dict__.update(**input_params.__dict__) + return params + + def __str__(self): + return "QInputParams<%s>" % str(self.__dict__) + + def __repr__(self): + return "QInputParams(...)" - self.q_dialect = self.determine_proper_dialect() - # TODO Isolate dialects of each Q instance - self.dialect_name = 'q' - csv.register_dialect(self.dialect_name, **self.q_dialect) +class QTextAsData(object): + def __init__(self,default_input_params=QInputParams(), + analyze_only=False + ): + self.default_input_params = default_input_params + + self.analyze_only = analyze_only self.table_creators = {} @@ -928,56 +937,67 @@ class QTextAsData(object): # ourselves instead of letting the csv module try to identify the types 'none' : csv.QUOTE_NONE } - def determine_proper_dialect(self): + def determine_proper_dialect(self,input_params): - input_quoting_mode_csv_numeral = QTextAsData.input_quoting_modes[self.input_quoting_mode] + input_quoting_mode_csv_numeral = QTextAsData.input_quoting_modes[input_params.input_quoting_mode] - if self.keep_leading_whitespace_in_values: + if input_params.keep_leading_whitespace_in_values: skip_initial_space = False else: skip_initial_space = True dialect = {'skipinitialspace': skip_initial_space, - 'delimiter': self.delimiter, 'quotechar': '"' } + 'delimiter': input_params.delimiter, 'quotechar': '"' } dialect['quoting'] = input_quoting_mode_csv_numeral - dialect['doublequote'] = self.disable_double_double_quoting + dialect['doublequote'] = input_params.disable_double_double_quoting - if self.disable_escaped_double_quoting: + if input_params.disable_escaped_double_quoting: dialect['escapechar'] = '\\' return dialect - def ensure_data_is_loaded(self,query_str,stdin_file,stdin_filename='-'): - data_loads = [] + def get_dialect_id(self,filename): + return 'q_dialect_%s' % filename - # Create SQL statment - sql_object = Sql('%s' % query_str) + def _load_data(self,filename,input_params=QInputParams(),stdin_file=None,stdin_filename='-'): + start_time = time.time() + + q_dialect = self.determine_proper_dialect(input_params) + dialect_id = self.get_dialect_id(filename) + csv.register_dialect(dialect_id, **q_dialect) # Create a line splitter - line_splitter = LineSplitter(self.delimiter, self.expected_column_count) + line_splitter = LineSplitter(input_params.delimiter, input_params.expected_column_count) - # Get each "table name" which is actually the file name - for filename in sql_object.qtable_names: - start_time = time.time() + # reuse already loaded data, except for stdin file data (stdin file data will always + # be reloaded and overwritten) + if filename in self.table_creators.keys() and filename != stdin_filename: + return None - # reuse already loaded data, except for stdin file data (stdin file data will always - # be reloaded and overwritten) - if filename in self.table_creators.keys() and filename != stdin_filename: - continue + # Create the matching database table and populate it + table_creator = TableCreator( + self.db, filename, line_splitter, input_params.skip_header, input_params.gzipped_input, input_params.input_encoding, + mode=input_params.parsing_mode, expected_column_count=input_params.expected_column_count, + input_delimiter=input_params.delimiter,stdin_file = stdin_file,stdin_filename = stdin_filename) + table_creator.populate(dialect_id,self.analyze_only) + self.table_creators[filename] = table_creator + + return QDataLoad(filename,start_time,time.time()) + + def load_data(self,filename,input_params=QInputParams()): + self._load_data(filename,input_params) - # Create the matching database table and populate it - table_creator = TableCreator( - self.db, filename, line_splitter, self.skip_header, self.gzipped_input, self.input_encoding, - mode=self.parsing_mode, expected_column_count=self.expected_column_count, - input_delimiter=self.delimiter,stdin_file = stdin_file,stdin_filename = stdin_filename) - table_creator.populate(self.dialect_name,self.analyze_only) - self.table_creators[filename] = table_creator + def ensure_data_is_loaded(self,query_str,input_params,stdin_file,stdin_filename='-'): + data_loads = [] - data_loads.append(QDataLoad(filename,start_time,time.time())) + # Create SQL statment + sql_object = Sql('%s' % query_str) - if self.debug: - print >>sys.stderr, "TIMING - populate time is %4.3f" % ( - time.time() - start_time) + # Get each "table name" which is actually the file name + for filename in sql_object.qtable_names: + data_load = self._load_data(filename,input_params,stdin_file=stdin_file,stdin_filename=stdin_filename) + if data_load is not None: + data_loads.append(data_load) return data_loads @@ -985,14 +1005,24 @@ class QTextAsData(object): for filename in sql_object.qtable_names: sql_object.set_effective_table_name(filename,self.table_creators[filename].table_name) - def execute(self,query_str,stdin_file=None,stdin_filename='-'): + def execute(self,query_str,input_params=None,stdin_file=None,stdin_filename='-'): warnings = [] error = None data_loads = [] db_results_obj = None + effective_input_params = self.default_input_params.merged_with(input_params) + + if type(query_str) != unicode: + try: + # Hueristic attempt to auto convert the query to unicode before failing + query_str = query_str.decode('utf-8') + except: + error = QError(EncodedQueryException(),"Query should be in unicode. Please make sure to provide a unicode literal string or decode it using proper the character encoding.",91) + return QOutput(error = error) + try: - data_loads += self.ensure_data_is_loaded(query_str,stdin_file=stdin_file,stdin_filename=stdin_filename) + data_loads += self.ensure_data_is_loaded(query_str,effective_input_params,stdin_file=stdin_file,stdin_filename=stdin_filename) # Create SQL statment sql_object = Sql(query_str) @@ -1006,7 +1036,7 @@ class QTextAsData(object): for k in column_names: column_type = table_creator.column_inferer.get_column_dict()[k] print " `%s` - %s" % (k, self.db.type_names[column_type].lower()) - return QOutput(data = None,column_name_list = None,warnings = warnings, error = None, data_loads = []) + return QOutput() # Execute the query and fetch the data db_results_obj = sql_object.execute_and_fetch(self.db) @@ -1018,7 +1048,7 @@ class QTextAsData(object): except sqlite3.OperationalError, e: msg = str(e) error = QError(e,"query error: %s" % msg,1) - if "no such column" in msg and self.skip_header: + if "no such column" in msg and effective_input_params.skip_header: warnings.append(QWarning(e,'Warning - There seems to be a "no such column" error, and -H (header line) exists. Please make sure that you are using the column names from the header line and not the default (cXX) column names')) except ColumnCountMismatchException, e: error = QError(e,e.msg,2) @@ -1037,7 +1067,7 @@ class QTextAsData(object): except KeyboardInterrupt,e: warnings.append(QWarning(e,"Interrupted")) except Exception, e: - error = QError(e,str(e),199) + error = QError(e,repr(e),199) if db_results_obj is not None: return QOutput( @@ -1047,7 +1077,7 @@ class QTextAsData(object): error = error, data_loads = data_loads) else: - return QOutput(data = None, column_name_list = None,warnings = warnings,error = error , data_loads = data_loads) + return QOutput(warnings = warnings,error = error , data_loads = data_loads) def unload(self): @@ -1083,25 +1113,36 @@ def quote_nonnumeric_func(output_delimiter,v): def quote_all_func(output_delimiter,v): return '"%s"' % (v) -class QOutputPrinter(object): - output_quoting_modes = { 'minimal' : quote_minimal_func, - 'all' : quote_all_func, - 'nonnumeric' : quote_nonnumeric_func, - 'none' : quote_none_func } - +class QOutputParams(object): def __init__(self, - delimiter=' ', - beautify=False, - output_quoting_mode='minimal', - formatting=None, - output_header=False): + delimiter=' ', + beautify=False, + output_quoting_mode='minimal', + formatting=None, + output_header=False): self.delimiter = delimiter self.beautify = beautify self.output_quoting_mode = output_quoting_mode self.formatting = formatting self.output_header = output_header - self.output_field_quoting_func = QOutputPrinter.output_quoting_modes[self.output_quoting_mode] + def __str__(self): + return "QOutputParams<%s>" % str(self.__dict__) + + def __repr__(self): + return "QOutputParams(...)" + + +class QOutputPrinter(object): + output_quoting_modes = { 'minimal' : quote_minimal_func, + 'all' : quote_all_func, + 'nonnumeric' : quote_nonnumeric_func, + 'none' : quote_none_func } + + def __init__(self,output_params): + self.output_params = output_params + + self.output_field_quoting_func = QOutputPrinter.output_quoting_modes[output_params.output_quoting_mode] def print_output(self,f,results): @@ -1118,17 +1159,17 @@ class QOutputPrinter(object): return # If the user requested beautifying the output - if self.beautify: - max_lengths = determine_max_col_lengths(data,self.output_field_quoting_func,self.delimiter) + if self.output_params.beautify: + max_lengths = determine_max_col_lengths(data,self.output_field_quoting_func,self.output_params.delimiter) - if self.formatting: + if self.output_params.formatting: formatting_dict = dict( - [(x.split("=")[0], x.split("=")[1]) for x in self.formatting.split(",")]) + [(x.split("=")[0], x.split("=")[1]) for x in self.output_params.formatting.split(",")]) else: formatting_dict = None try: - if self.output_header and results.column_name_list is not None: + if self.output_params.output_header and results.column_name_list is not None: data.insert(0,results.column_name_list) for rownum, row in enumerate(data): row_str = [] @@ -1136,17 +1177,17 @@ class QOutputPrinter(object): if formatting_dict is not None and str(i + 1) in formatting_dict.keys(): fmt_str = formatting_dict[str(i + 1)] else: - if self.beautify: + if self.output_params.beautify: fmt_str = "%%-%ss" % max_lengths[i] else: fmt_str = "%s" if col is not None: - row_str.append(fmt_str % self.output_field_quoting_func(self.delimiter,col)) + row_str.append(fmt_str % self.output_field_quoting_func(self.output_params.delimiter,col)) else: row_str.append(fmt_str % "") - f.write(self.delimiter.join(row_str) + "\n") + f.write(self.output_params.delimiter.join(row_str) + "\n") except (UnicodeEncodeError, UnicodeError), e: print >>sys.stderr, "Cannot encode data. Error:%s" % e sys.exit(3) @@ -1382,8 +1423,7 @@ def run_standalone(): # (since no input delimiter means any whitespace) options.output_delimiter = " " - q_query = QTextAsData( - skip_header=options.skip_header, + default_input_params = QInputParams(skip_header=options.skip_header, delimiter=options.delimiter, input_encoding=options.encoding, gzipped_input=options.gzipped, @@ -1392,21 +1432,23 @@ def run_standalone(): keep_leading_whitespace_in_values=options.keep_leading_whitespace_in_values, disable_double_double_quoting=options.disable_double_double_quoting, disable_escaped_double_quoting=options.disable_escaped_double_quoting, - input_quoting_mode=options.input_quoting_mode, - query_encoding=options.query_encoding, - debug=DEBUG, - analyze_only=options.analyze_only) + input_quoting_mode=options.input_quoting_mode) + + q_query = QTextAsData( + default_input_params=default_input_params,analyze_only=options.analyze_only) q_output = q_query.execute(query_str,stdin_file=sys.stdin) q_query.unload() - q_output_printer = QOutputPrinter( + output_params = QOutputParams( delimiter=options.output_delimiter, beautify=options.beautify, output_quoting_mode=options.output_quoting_mode, formatting=options.formatting, output_header=options.output_header) + q_output_printer = QOutputPrinter(output_params) + try: q_output_printer.print_output(STDOUT,q_output) diff --git a/test/test-suite b/test/test-suite index 39f82ce..cc3ecca 100755 --- a/test/test-suite +++ b/test/test-suite @@ -21,7 +21,7 @@ import locale import pprint sys.path.append(os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])),'..','bin')) -from qtextasdata import QTextAsData,QOutput,QOutputPrinter +from qtextasdata import QTextAsData,QOutput,QOutputPrinter,QInputParams # q uses this encoding as the default output encoding. Some of the tests use it in order to # make sure that the output is correctly encoded @@ -422,6 +422,7 @@ class BasicTests(AbstractQTestCase): def test_generated_column_name_warning_when_header_line_exists(self): tmpfile = self.create_file_with_data(sample_data_with_header) cmd = '../bin/q -d , "select c3 from %s" -H' % tmpfile.name + retcode, o, e = run_command(cmd) self.assertNotEquals(retcode, 0) @@ -1347,7 +1348,7 @@ class ParsingModeTests(AbstractQTestCase): tmpfile = self.create_file_with_data(uneven_ls_output) cmd = '../bin/q -m relaxed "select count(*) from %s" -A' % tmpfile.name retcode, o, e = run_command(cmd) - + self.assertEquals(retcode, 0) self.assertEquals(len(e), 0) @@ -1586,7 +1587,7 @@ class BasicModuleTests(AbstractQTestCase): def test_simple_query(self): tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6") - q = QTextAsData(skip_header=True,delimiter=' ') + q = QTextAsData(QInputParams(skip_header=True,delimiter=' ')) r = q.execute('select * from %s' % tmpfile.name) self.assertFalse(r.error_occurred()) @@ -1602,7 +1603,7 @@ class BasicModuleTests(AbstractQTestCase): def test_loaded_data_reuse(self): tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6") - q = QTextAsData(skip_header=True,delimiter=' ') + q = QTextAsData(QInputParams(skip_header=True,delimiter=' ')) r1 = q.execute('select * from %s' % tmpfile.name) r2 = q.execute('select * from %s' % tmpfile.name) @@ -1627,7 +1628,7 @@ class BasicModuleTests(AbstractQTestCase): def test_stdin_injection(self): tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6") - q = QTextAsData(skip_header=True,delimiter=' ') + q = QTextAsData(QInputParams(skip_header=True,delimiter=' ')) r = q.execute('select * from -',stdin_file=file(tmpfile.name,'rb')) self.assertFalse(r.error_occurred()) @@ -1643,7 +1644,7 @@ class BasicModuleTests(AbstractQTestCase): def test_named_stdin_injection(self): tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6") - q = QTextAsData(skip_header=True,delimiter=' ') + q = QTextAsData(QInputParams(skip_header=True,delimiter=' ')) r = q.execute('select a from my_stdin_data',stdin_file=file(tmpfile.name,'rb'),stdin_filename='my_stdin_data') self.assertFalse(r.error_occurred()) @@ -1660,7 +1661,7 @@ class BasicModuleTests(AbstractQTestCase): tmpfile1 = self.create_file_with_data("a b c\n1 2 3\n4 5 6") tmpfile2 = self.create_file_with_data("d e f\n7 8 9\n10 11 12") - q = QTextAsData(skip_header=True,delimiter=' ') + q = QTextAsData(QInputParams(skip_header=True,delimiter=' ')) r1 = q.execute('select * from -',stdin_file=file(tmpfile1.name,'rb')) self.assertFalse(r1.error_occurred()) @@ -1689,7 +1690,7 @@ class BasicModuleTests(AbstractQTestCase): tmpfile1 = self.create_file_with_data("a b c\n1 2 3\n4 5 6") tmpfile2 = self.create_file_with_data("d e f\n7 8 9\n10 11 12") - q = QTextAsData(skip_header=True,delimiter=' ') + q = QTextAsData(QInputParams(skip_header=True,delimiter=' ')) r1 = q.execute('select * from my_stdin_data1',stdin_file=file(tmpfile1.name,'rb'),stdin_filename='my_stdin_data1') self.assertFalse(r1.error_occurred()) @@ -1722,7 +1723,90 @@ class BasicModuleTests(AbstractQTestCase): self.cleanup(tmpfile1) self.cleanup(tmpfile2) - + + def test_different_input_params_for_different_files(self): + tmpfile1 = self.create_file_with_data("a b c\n1 2 3\n4 5 6") + tmpfile2 = self.create_file_with_data("7\t8\t9\n10\t11\t12") + + q = QTextAsData(QInputParams(skip_header=True,delimiter=' ')) + + q.load_data(tmpfile1.name,QInputParams(skip_header=True,delimiter=' ')) + q.load_data(tmpfile2.name,QInputParams(skip_header=False,delimiter='\t')) + + r = q.execute('select aa.*,bb.* from %s aa join %s bb' % (tmpfile1.name,tmpfile2.name)) + + self.assertFalse(r.error_occurred()) + self.assertEquals(len(r.warnings),0) + self.assertEquals(len(r.data),4) + self.assertEquals(r.column_name_list,['a','b','c','c1','c2','c3']) + self.assertEquals(r.data,[(1,2,3,7,8,9),(1,2,3,10,11,12),(4,5,6,7,8,9),(4,5,6,10,11,12)]) + self.assertEquals(len(r.data_loads),0) + + self.cleanup(tmpfile1) + self.cleanup(tmpfile2) + + def test_different_input_params_for_different_files(self): + tmpfile1 = self.create_file_with_data("a b c\n1 2 3\n4 5 6") + tmpfile2 = self.create_file_with_data("7\t8\t9\n10\t11\t12") + + q = QTextAsData() + + q.load_data(tmpfile1.name,QInputParams(skip_header=True,delimiter=' ')) + q.load_data(tmpfile2.name,QInputParams(skip_header=False,delimiter='\t')) + + r = q.execute('select aa.*,bb.* from %s aa join %s bb' % (tmpfile1.name,tmpfile2.name)) + + self.assertFalse(r.error_occurred()) + self.assertEquals(len(r.warnings),0) + self.assertEquals(len(r.data),4) + self.assertEquals(r.column_name_list,['a','b','c','c1','c2','c3']) + self.assertEquals(r.data,[(1,2,3,7,8,9),(1,2,3,10,11,12),(4,5,6,7,8,9),(4,5,6,10,11,12)]) + self.assertEquals(len(r.data_loads),0) + + self.cleanup(tmpfile1) + self.cleanup(tmpfile2) + + def test_input_params_override(self): + tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6") + + default_input_params = QInputParams() + + for k in default_input_params.__dict__.keys(): + setattr(default_input_params,k,'GARBAGE') + + q = QTextAsData(default_input_params) + + r = q.execute('select * from %s' % tmpfile.name) + + self.assertTrue(r.error_occurred()) + + overwriting_input_params = QInputParams(skip_header=True,delimiter=' ') + + r2 = q.execute('select * from %s' % tmpfile.name,input_params=overwriting_input_params) + + self.assertFalse(r2.error_occurred()) + self.assertEquals(len(r2.warnings),0) + self.assertEquals(len(r2.data),2) + self.assertEquals(r2.column_name_list,['a','b','c']) + self.assertEquals(r2.data,[(1,2,3),(4,5,6)]) + self.assertEquals(len(r2.data_loads),1) + self.assertEquals(r2.data_loads[0].filename,tmpfile.name) + + self.cleanup(tmpfile) + + def test_input_params_merge(self): + input_params = QInputParams() + + for k in input_params.__dict__.keys(): + setattr(input_params,k,'GARBAGE') + + merged_input_params = input_params.merged_with(QInputParams()) + + for k in merged_input_params.__dict__.keys(): + self.assertTrue(getattr(merged_input_params,k) != 'GARBAGE') + + for k in input_params.__dict__.keys(): + self.assertTrue(getattr(merged_input_params,k) != 'GARBAGE') def suite(): tl = unittest.TestLoader() -- cgit v1.2.3