summaryrefslogtreecommitdiffstats
diff options
context:
space:
mode:
authorHarel Ben-Attia <harelba@gmail.com>2014-11-15 19:41:47 -0500
committerHarel Ben-Attia <harelba@gmail.com>2014-11-15 19:41:47 -0500
commit08cf2bb897ee7ce4ebac22b4ac088a8a456f086f (patch)
treefd192794699fa3eabb80af33c2ba808d1ba0eb18
parentfaf39043dbe559d97428e193e13c38ec61c1b0a1 (diff)
Allow different input params for each loaded file + loading file manually + fixed modeling of query encoding + tests
-rwxr-xr-xbin/q216
-rwxr-xr-xtest/test-suite102
2 files changed, 222 insertions, 96 deletions
diff --git a/bin/q b/bin/q
index e579e55..e7e5fd6 100755
--- a/bin/q
+++ b/bin/q
@@ -195,6 +195,14 @@ class BadHeaderException(Exception):
def __str(self):
return repr(self.msg)
+class EncodedQueryException(Exception):
+
+ def __init__(self, msg):
+ self.msg = msg
+
+ def __str(self):
+ return repr(self.msg)
+
class CannotUnzipStdInException(Exception):
@@ -836,6 +844,7 @@ class QError(object):
self.exception = exception
self.msg = msg
self.errorcode = errorcode
+ self.traceback = traceback.format_exc()
class QDataLoad(object):
def __init__(self,filename,start_time,end_time):
@@ -851,7 +860,7 @@ class QDataLoad(object):
__repr__ = __str__
class QOutput(object):
- def __init__(self,data,column_name_list,warnings,error,data_loads):
+ def __init__(self,data=None,column_name_list=None,warnings=[],error=None,data_loads=[]):
self.data = data
self.column_name_list = column_name_list
self.warnings = warnings
@@ -879,23 +888,12 @@ class QOutput(object):
return "QOutput<%s>" % ",".join(s)
__repr__ = __str__
-class QTextAsData(object):
- def __init__(self,
- skip_header=False,
- delimiter=' ',
- input_encoding='UTF-8',
- gzipped_input=False,
- parsing_mode='relaxed',
- expected_column_count=None,
- keep_leading_whitespace_in_values=False,
- disable_double_double_quoting=False,
- disable_escaped_double_quoting=False,
- input_quoting_mode='minimal',
-
- query_encoding=locale.getpreferredencoding(),
- debug=False,
- analyze_only=False
- ):
+class QInputParams(object):
+ def __init__(self,skip_header=False,
+ delimiter=' ',input_encoding='UTF-8',gzipped_input=False,parsing_mode='relaxed',
+ expected_column_count=None,keep_leading_whitespace_in_values=False,
+ disable_double_double_quoting=False,disable_escaped_double_quoting=False,
+ input_quoting_mode='minimal',stdin_file=None,stdin_filename='-'):
self.skip_header = skip_header
self.delimiter = delimiter
self.input_encoding = input_encoding
@@ -907,14 +905,25 @@ class QTextAsData(object):
self.disable_escaped_double_quoting = disable_escaped_double_quoting
self.input_quoting_mode = input_quoting_mode
- self.query_encoding = query_encoding
- self.debug = debug
- self.analyze_only = analyze_only
+ def merged_with(self,input_params):
+ params = QInputParams(**self.__dict__)
+ if input_params is not None:
+ params.__dict__.update(**input_params.__dict__)
+ return params
+
+ def __str__(self):
+ return "QInputParams<%s>" % str(self.__dict__)
+
+ def __repr__(self):
+ return "QInputParams(...)"
- self.q_dialect = self.determine_proper_dialect()
- # TODO Isolate dialects of each Q instance
- self.dialect_name = 'q'
- csv.register_dialect(self.dialect_name, **self.q_dialect)
+class QTextAsData(object):
+ def __init__(self,default_input_params=QInputParams(),
+ analyze_only=False
+ ):
+ self.default_input_params = default_input_params
+
+ self.analyze_only = analyze_only
self.table_creators = {}
@@ -928,56 +937,67 @@ class QTextAsData(object):
# ourselves instead of letting the csv module try to identify the types
'none' : csv.QUOTE_NONE }
- def determine_proper_dialect(self):
+ def determine_proper_dialect(self,input_params):
- input_quoting_mode_csv_numeral = QTextAsData.input_quoting_modes[self.input_quoting_mode]
+ input_quoting_mode_csv_numeral = QTextAsData.input_quoting_modes[input_params.input_quoting_mode]
- if self.keep_leading_whitespace_in_values:
+ if input_params.keep_leading_whitespace_in_values:
skip_initial_space = False
else:
skip_initial_space = True
dialect = {'skipinitialspace': skip_initial_space,
- 'delimiter': self.delimiter, 'quotechar': '"' }
+ 'delimiter': input_params.delimiter, 'quotechar': '"' }
dialect['quoting'] = input_quoting_mode_csv_numeral
- dialect['doublequote'] = self.disable_double_double_quoting
+ dialect['doublequote'] = input_params.disable_double_double_quoting
- if self.disable_escaped_double_quoting:
+ if input_params.disable_escaped_double_quoting:
dialect['escapechar'] = '\\'
return dialect
- def ensure_data_is_loaded(self,query_str,stdin_file,stdin_filename='-'):
- data_loads = []
+ def get_dialect_id(self,filename):
+ return 'q_dialect_%s' % filename
- # Create SQL statment
- sql_object = Sql('%s' % query_str)
+ def _load_data(self,filename,input_params=QInputParams(),stdin_file=None,stdin_filename='-'):
+ start_time = time.time()
+
+ q_dialect = self.determine_proper_dialect(input_params)
+ dialect_id = self.get_dialect_id(filename)
+ csv.register_dialect(dialect_id, **q_dialect)
# Create a line splitter
- line_splitter = LineSplitter(self.delimiter, self.expected_column_count)
+ line_splitter = LineSplitter(input_params.delimiter, input_params.expected_column_count)
- # Get each "table name" which is actually the file name
- for filename in sql_object.qtable_names:
- start_time = time.time()
+ # reuse already loaded data, except for stdin file data (stdin file data will always
+ # be reloaded and overwritten)
+ if filename in self.table_creators.keys() and filename != stdin_filename:
+ return None
- # reuse already loaded data, except for stdin file data (stdin file data will always
- # be reloaded and overwritten)
- if filename in self.table_creators.keys() and filename != stdin_filename:
- continue
+ # Create the matching database table and populate it
+ table_creator = TableCreator(
+ self.db, filename, line_splitter, input_params.skip_header, input_params.gzipped_input, input_params.input_encoding,
+ mode=input_params.parsing_mode, expected_column_count=input_params.expected_column_count,
+ input_delimiter=input_params.delimiter,stdin_file = stdin_file,stdin_filename = stdin_filename)
+ table_creator.populate(dialect_id,self.analyze_only)
+ self.table_creators[filename] = table_creator
+
+ return QDataLoad(filename,start_time,time.time())
+
+ def load_data(self,filename,input_params=QInputParams()):
+ self._load_data(filename,input_params)
- # Create the matching database table and populate it
- table_creator = TableCreator(
- self.db, filename, line_splitter, self.skip_header, self.gzipped_input, self.input_encoding,
- mode=self.parsing_mode, expected_column_count=self.expected_column_count,
- input_delimiter=self.delimiter,stdin_file = stdin_file,stdin_filename = stdin_filename)
- table_creator.populate(self.dialect_name,self.analyze_only)
- self.table_creators[filename] = table_creator
+ def ensure_data_is_loaded(self,query_str,input_params,stdin_file,stdin_filename='-'):
+ data_loads = []
- data_loads.append(QDataLoad(filename,start_time,time.time()))
+ # Create SQL statment
+ sql_object = Sql('%s' % query_str)
- if self.debug:
- print >>sys.stderr, "TIMING - populate time is %4.3f" % (
- time.time() - start_time)
+ # Get each "table name" which is actually the file name
+ for filename in sql_object.qtable_names:
+ data_load = self._load_data(filename,input_params,stdin_file=stdin_file,stdin_filename=stdin_filename)
+ if data_load is not None:
+ data_loads.append(data_load)
return data_loads
@@ -985,14 +1005,24 @@ class QTextAsData(object):
for filename in sql_object.qtable_names:
sql_object.set_effective_table_name(filename,self.table_creators[filename].table_name)
- def execute(self,query_str,stdin_file=None,stdin_filename='-'):
+ def execute(self,query_str,input_params=None,stdin_file=None,stdin_filename='-'):
warnings = []
error = None
data_loads = []
db_results_obj = None
+ effective_input_params = self.default_input_params.merged_with(input_params)
+
+ if type(query_str) != unicode:
+ try:
+ # Hueristic attempt to auto convert the query to unicode before failing
+ query_str = query_str.decode('utf-8')
+ except:
+ error = QError(EncodedQueryException(),"Query should be in unicode. Please make sure to provide a unicode literal string or decode it using proper the character encoding.",91)
+ return QOutput(error = error)
+
try:
- data_loads += self.ensure_data_is_loaded(query_str,stdin_file=stdin_file,stdin_filename=stdin_filename)
+ data_loads += self.ensure_data_is_loaded(query_str,effective_input_params,stdin_file=stdin_file,stdin_filename=stdin_filename)
# Create SQL statment
sql_object = Sql(query_str)
@@ -1006,7 +1036,7 @@ class QTextAsData(object):
for k in column_names:
column_type = table_creator.column_inferer.get_column_dict()[k]
print " `%s` - %s" % (k, self.db.type_names[column_type].lower())
- return QOutput(data = None,column_name_list = None,warnings = warnings, error = None, data_loads = [])
+ return QOutput()
# Execute the query and fetch the data
db_results_obj = sql_object.execute_and_fetch(self.db)
@@ -1018,7 +1048,7 @@ class QTextAsData(object):
except sqlite3.OperationalError, e:
msg = str(e)
error = QError(e,"query error: %s" % msg,1)
- if "no such column" in msg and self.skip_header:
+ if "no such column" in msg and effective_input_params.skip_header:
warnings.append(QWarning(e,'Warning - There seems to be a "no such column" error, and -H (header line) exists. Please make sure that you are using the column names from the header line and not the default (cXX) column names'))
except ColumnCountMismatchException, e:
error = QError(e,e.msg,2)
@@ -1037,7 +1067,7 @@ class QTextAsData(object):
except KeyboardInterrupt,e:
warnings.append(QWarning(e,"Interrupted"))
except Exception, e:
- error = QError(e,str(e),199)
+ error = QError(e,repr(e),199)
if db_results_obj is not None:
return QOutput(
@@ -1047,7 +1077,7 @@ class QTextAsData(object):
error = error,
data_loads = data_loads)
else:
- return QOutput(data = None, column_name_list = None,warnings = warnings,error = error , data_loads = data_loads)
+ return QOutput(warnings = warnings,error = error , data_loads = data_loads)
def unload(self):
@@ -1083,25 +1113,36 @@ def quote_nonnumeric_func(output_delimiter,v):
def quote_all_func(output_delimiter,v):
return '"%s"' % (v)
-class QOutputPrinter(object):
- output_quoting_modes = { 'minimal' : quote_minimal_func,
- 'all' : quote_all_func,
- 'nonnumeric' : quote_nonnumeric_func,
- 'none' : quote_none_func }
-
+class QOutputParams(object):
def __init__(self,
- delimiter=' ',
- beautify=False,
- output_quoting_mode='minimal',
- formatting=None,
- output_header=False):
+ delimiter=' ',
+ beautify=False,
+ output_quoting_mode='minimal',
+ formatting=None,
+ output_header=False):
self.delimiter = delimiter
self.beautify = beautify
self.output_quoting_mode = output_quoting_mode
self.formatting = formatting
self.output_header = output_header
- self.output_field_quoting_func = QOutputPrinter.output_quoting_modes[self.output_quoting_mode]
+ def __str__(self):
+ return "QOutputParams<%s>" % str(self.__dict__)
+
+ def __repr__(self):
+ return "QOutputParams(...)"
+
+
+class QOutputPrinter(object):
+ output_quoting_modes = { 'minimal' : quote_minimal_func,
+ 'all' : quote_all_func,
+ 'nonnumeric' : quote_nonnumeric_func,
+ 'none' : quote_none_func }
+
+ def __init__(self,output_params):
+ self.output_params = output_params
+
+ self.output_field_quoting_func = QOutputPrinter.output_quoting_modes[output_params.output_quoting_mode]
def print_output(self,f,results):
@@ -1118,17 +1159,17 @@ class QOutputPrinter(object):
return
# If the user requested beautifying the output
- if self.beautify:
- max_lengths = determine_max_col_lengths(data,self.output_field_quoting_func,self.delimiter)
+ if self.output_params.beautify:
+ max_lengths = determine_max_col_lengths(data,self.output_field_quoting_func,self.output_params.delimiter)
- if self.formatting:
+ if self.output_params.formatting:
formatting_dict = dict(
- [(x.split("=")[0], x.split("=")[1]) for x in self.formatting.split(",")])
+ [(x.split("=")[0], x.split("=")[1]) for x in self.output_params.formatting.split(",")])
else:
formatting_dict = None
try:
- if self.output_header and results.column_name_list is not None:
+ if self.output_params.output_header and results.column_name_list is not None:
data.insert(0,results.column_name_list)
for rownum, row in enumerate(data):
row_str = []
@@ -1136,17 +1177,17 @@ class QOutputPrinter(object):
if formatting_dict is not None and str(i + 1) in formatting_dict.keys():
fmt_str = formatting_dict[str(i + 1)]
else:
- if self.beautify:
+ if self.output_params.beautify:
fmt_str = "%%-%ss" % max_lengths[i]
else:
fmt_str = "%s"
if col is not None:
- row_str.append(fmt_str % self.output_field_quoting_func(self.delimiter,col))
+ row_str.append(fmt_str % self.output_field_quoting_func(self.output_params.delimiter,col))
else:
row_str.append(fmt_str % "")
- f.write(self.delimiter.join(row_str) + "\n")
+ f.write(self.output_params.delimiter.join(row_str) + "\n")
except (UnicodeEncodeError, UnicodeError), e:
print >>sys.stderr, "Cannot encode data. Error:%s" % e
sys.exit(3)
@@ -1382,8 +1423,7 @@ def run_standalone():
# (since no input delimiter means any whitespace)
options.output_delimiter = " "
- q_query = QTextAsData(
- skip_header=options.skip_header,
+ default_input_params = QInputParams(skip_header=options.skip_header,
delimiter=options.delimiter,
input_encoding=options.encoding,
gzipped_input=options.gzipped,
@@ -1392,21 +1432,23 @@ def run_standalone():
keep_leading_whitespace_in_values=options.keep_leading_whitespace_in_values,
disable_double_double_quoting=options.disable_double_double_quoting,
disable_escaped_double_quoting=options.disable_escaped_double_quoting,
- input_quoting_mode=options.input_quoting_mode,
- query_encoding=options.query_encoding,
- debug=DEBUG,
- analyze_only=options.analyze_only)
+ input_quoting_mode=options.input_quoting_mode)
+
+ q_query = QTextAsData(
+ default_input_params=default_input_params,analyze_only=options.analyze_only)
q_output = q_query.execute(query_str,stdin_file=sys.stdin)
q_query.unload()
- q_output_printer = QOutputPrinter(
+ output_params = QOutputParams(
delimiter=options.output_delimiter,
beautify=options.beautify,
output_quoting_mode=options.output_quoting_mode,
formatting=options.formatting,
output_header=options.output_header)
+ q_output_printer = QOutputPrinter(output_params)
+
try:
q_output_printer.print_output(STDOUT,q_output)
diff --git a/test/test-suite b/test/test-suite
index 39f82ce..cc3ecca 100755
--- a/test/test-suite
+++ b/test/test-suite
@@ -21,7 +21,7 @@ import locale
import pprint
sys.path.append(os.path.join(os.path.abspath(os.path.dirname(sys.argv[0])),'..','bin'))
-from qtextasdata import QTextAsData,QOutput,QOutputPrinter
+from qtextasdata import QTextAsData,QOutput,QOutputPrinter,QInputParams
# q uses this encoding as the default output encoding. Some of the tests use it in order to
# make sure that the output is correctly encoded
@@ -422,6 +422,7 @@ class BasicTests(AbstractQTestCase):
def test_generated_column_name_warning_when_header_line_exists(self):
tmpfile = self.create_file_with_data(sample_data_with_header)
cmd = '../bin/q -d , "select c3 from %s" -H' % tmpfile.name
+
retcode, o, e = run_command(cmd)
self.assertNotEquals(retcode, 0)
@@ -1347,7 +1348,7 @@ class ParsingModeTests(AbstractQTestCase):
tmpfile = self.create_file_with_data(uneven_ls_output)
cmd = '../bin/q -m relaxed "select count(*) from %s" -A' % tmpfile.name
retcode, o, e = run_command(cmd)
-
+
self.assertEquals(retcode, 0)
self.assertEquals(len(e), 0)
@@ -1586,7 +1587,7 @@ class BasicModuleTests(AbstractQTestCase):
def test_simple_query(self):
tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6")
- q = QTextAsData(skip_header=True,delimiter=' ')
+ q = QTextAsData(QInputParams(skip_header=True,delimiter=' '))
r = q.execute('select * from %s' % tmpfile.name)
self.assertFalse(r.error_occurred())
@@ -1602,7 +1603,7 @@ class BasicModuleTests(AbstractQTestCase):
def test_loaded_data_reuse(self):
tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6")
- q = QTextAsData(skip_header=True,delimiter=' ')
+ q = QTextAsData(QInputParams(skip_header=True,delimiter=' '))
r1 = q.execute('select * from %s' % tmpfile.name)
r2 = q.execute('select * from %s' % tmpfile.name)
@@ -1627,7 +1628,7 @@ class BasicModuleTests(AbstractQTestCase):
def test_stdin_injection(self):
tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6")
- q = QTextAsData(skip_header=True,delimiter=' ')
+ q = QTextAsData(QInputParams(skip_header=True,delimiter=' '))
r = q.execute('select * from -',stdin_file=file(tmpfile.name,'rb'))
self.assertFalse(r.error_occurred())
@@ -1643,7 +1644,7 @@ class BasicModuleTests(AbstractQTestCase):
def test_named_stdin_injection(self):
tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6")
- q = QTextAsData(skip_header=True,delimiter=' ')
+ q = QTextAsData(QInputParams(skip_header=True,delimiter=' '))
r = q.execute('select a from my_stdin_data',stdin_file=file(tmpfile.name,'rb'),stdin_filename='my_stdin_data')
self.assertFalse(r.error_occurred())
@@ -1660,7 +1661,7 @@ class BasicModuleTests(AbstractQTestCase):
tmpfile1 = self.create_file_with_data("a b c\n1 2 3\n4 5 6")
tmpfile2 = self.create_file_with_data("d e f\n7 8 9\n10 11 12")
- q = QTextAsData(skip_header=True,delimiter=' ')
+ q = QTextAsData(QInputParams(skip_header=True,delimiter=' '))
r1 = q.execute('select * from -',stdin_file=file(tmpfile1.name,'rb'))
self.assertFalse(r1.error_occurred())
@@ -1689,7 +1690,7 @@ class BasicModuleTests(AbstractQTestCase):
tmpfile1 = self.create_file_with_data("a b c\n1 2 3\n4 5 6")
tmpfile2 = self.create_file_with_data("d e f\n7 8 9\n10 11 12")
- q = QTextAsData(skip_header=True,delimiter=' ')
+ q = QTextAsData(QInputParams(skip_header=True,delimiter=' '))
r1 = q.execute('select * from my_stdin_data1',stdin_file=file(tmpfile1.name,'rb'),stdin_filename='my_stdin_data1')
self.assertFalse(r1.error_occurred())
@@ -1722,7 +1723,90 @@ class BasicModuleTests(AbstractQTestCase):
self.cleanup(tmpfile1)
self.cleanup(tmpfile2)
-
+
+ def test_different_input_params_for_different_files(self):
+ tmpfile1 = self.create_file_with_data("a b c\n1 2 3\n4 5 6")
+ tmpfile2 = self.create_file_with_data("7\t8\t9\n10\t11\t12")
+
+ q = QTextAsData(QInputParams(skip_header=True,delimiter=' '))
+
+ q.load_data(tmpfile1.name,QInputParams(skip_header=True,delimiter=' '))
+ q.load_data(tmpfile2.name,QInputParams(skip_header=False,delimiter='\t'))
+
+ r = q.execute('select aa.*,bb.* from %s aa join %s bb' % (tmpfile1.name,tmpfile2.name))
+
+ self.assertFalse(r.error_occurred())
+ self.assertEquals(len(r.warnings),0)
+ self.assertEquals(len(r.data),4)
+ self.assertEquals(r.column_name_list,['a','b','c','c1','c2','c3'])
+ self.assertEquals(r.data,[(1,2,3,7,8,9),(1,2,3,10,11,12),(4,5,6,7,8,9),(4,5,6,10,11,12)])
+ self.assertEquals(len(r.data_loads),0)
+
+ self.cleanup(tmpfile1)
+ self.cleanup(tmpfile2)
+
+ def test_different_input_params_for_different_files(self):
+ tmpfile1 = self.create_file_with_data("a b c\n1 2 3\n4 5 6")
+ tmpfile2 = self.create_file_with_data("7\t8\t9\n10\t11\t12")
+
+ q = QTextAsData()
+
+ q.load_data(tmpfile1.name,QInputParams(skip_header=True,delimiter=' '))
+ q.load_data(tmpfile2.name,QInputParams(skip_header=False,delimiter='\t'))
+
+ r = q.execute('select aa.*,bb.* from %s aa join %s bb' % (tmpfile1.name,tmpfile2.name))
+
+ self.assertFalse(r.error_occurred())
+ self.assertEquals(len(r.warnings),0)
+ self.assertEquals(len(r.data),4)
+ self.assertEquals(r.column_name_list,['a','b','c','c1','c2','c3'])
+ self.assertEquals(r.data,[(1,2,3,7,8,9),(1,2,3,10,11,12),(4,5,6,7,8,9),(4,5,6,10,11,12)])
+ self.assertEquals(len(r.data_loads),0)
+
+ self.cleanup(tmpfile1)
+ self.cleanup(tmpfile2)
+
+ def test_input_params_override(self):
+ tmpfile = self.create_file_with_data("a b c\n1 2 3\n4 5 6")
+
+ default_input_params = QInputParams()
+
+ for k in default_input_params.__dict__.keys():
+ setattr(default_input_params,k,'GARBAGE')
+
+ q = QTextAsData(default_input_params)
+
+ r = q.execute('select * from %s' % tmpfile.name)
+
+ self.assertTrue(r.error_occurred())
+
+ overwriting_input_params = QInputParams(skip_header=True,delimiter=' ')
+
+ r2 = q.execute('select * from %s' % tmpfile.name,input_params=overwriting_input_params)
+
+ self.assertFalse(r2.error_occurred())
+ self.assertEquals(len(r2.warnings),0)
+ self.assertEquals(len(r2.data),2)
+ self.assertEquals(r2.column_name_list,['a','b','c'])
+ self.assertEquals(r2.data,[(1,2,3),(4,5,6)])
+ self.assertEquals(len(r2.data_loads),1)
+ self.assertEquals(r2.data_loads[0].filename,tmpfile.name)
+
+ self.cleanup(tmpfile)
+
+ def test_input_params_merge(self):
+ input_params = QInputParams()
+
+ for k in input_params.__dict__.keys():
+ setattr(input_params,k,'GARBAGE')
+
+ merged_input_params = input_params.merged_with(QInputParams())
+
+ for k in merged_input_params.__dict__.keys():
+ self.assertTrue(getattr(merged_input_params,k) != 'GARBAGE')
+
+ for k in input_params.__dict__.keys():
+ self.assertTrue(getattr(merged_input_params,k) != 'GARBAGE')
def suite():
tl = unittest.TestLoader()