author    Harel Ben-Attia <harelba@gmail.com>    2019-12-21 15:11:48 +0200
committer Harel Ben-Attia <harelba@gmail.com>    2019-12-21 15:11:48 +0200
commit    a603ab65c560bd4deec97c7068c47740e1bce7e5 (patch)
tree      4e8207f616262e45283e23b06bac778a2f630973
parent    5c96ad2904b2405e4fada5fdbdcacb8d2f63eed2 (diff)
Fix output header issue for multi-file tables https://github.com/harelba/q/issues/212 (tags: 2.0.9, 2.0.8)
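
With this change, a query over a concatenated multi-file table (e.g. file1.csv+file2.csv) run with -H skips the header line of each additional file instead of loading it as data, and fails with a BadHeaderException (exit code 35) when an extra header does not match the original one. A minimal usage sketch of the intended behavior, roughly as exercised by the new tests (hypothetical file names, assuming q is on the PATH):

    $ cat a.csv
    name,value1,value2
    a,1,10
    $ cat b.csv
    name,value1,value2
    b,2,20
    $ q -d , -H -O "select name,value1,value2 from a.csv+b.csv"
    name,value1,value2
    a,1,10
    b,2,20
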
-rwxr-xr-x  bin/q             47
-rwxr-xr-x  test/test-suite  109
2 files changed, 142 insertions, 14 deletions
diff --git a/bin/q b/bin/q
index fbd5879..56c6fe0 100755
--- a/bin/q
+++ b/bin/q
@@ -476,16 +476,18 @@ class TableColumnInferer(object):
self.rows = []
self.skip_header = skip_header
self.header_row = None
+ self.header_row_filename = None
self.expected_column_count = expected_column_count
self.input_delimiter = input_delimiter
self.disable_column_type_detection = disable_column_type_detection
- def analyze(self, col_vals):
+ def analyze(self, filename, col_vals):
if self.inferred:
raise Exception("Already inferred columns")
if self.skip_header and self.header_row is None:
self.header_row = col_vals
+ self.header_row_filename = filename
else:
self.rows.append(col_vals)
@@ -905,17 +907,36 @@ class TableCreator(object):
mfs = MaterializedFileState(filename,f,self.encoding,dialect,is_stdin)
self.materialized_file_dict[filename] = mfs
+ def _should_skip_extra_headers(self, filenumber, filename, mfs, col_vals):
+ if not self.skip_header:
+ return False
+
+ if filenumber == 0:
+ return False
+
+ header_already_exists = self.column_inferer.header_row is not None
+
+ is_extra_header = self.skip_header and mfs.lines_read == 1 and header_already_exists
+
+ if is_extra_header:
+ if tuple(self.column_inferer.header_row) != tuple(col_vals):
+ raise BadHeaderException("Extra header {} in file {} mismatches original header {} from file {}. Table name is {}".format(",".join(col_vals),mfs.filename,",".join(self.column_inferer.header_row),self.column_inferer.header_row_filename,self.filenames_str))
+
+ return is_extra_header
+
def _populate(self,dialect,stop_after_analysis=False):
total_data_lines_read = 0
# For each match
- for filename in self.materialized_file_list:
+ for filenumber,filename in enumerate(self.materialized_file_list):
mfs = self.materialized_file_dict[filename]
try:
try:
for col_vals in mfs.read_file_using_csv():
- self._insert_row(col_vals)
+ if self._should_skip_extra_headers(filenumber,filename,mfs,col_vals):
+ continue
+ self._insert_row(filename, col_vals)
if stop_after_analysis and self.column_inferer.inferred:
return
if mfs.lines_read == 0 and self.skip_header:
@@ -937,7 +958,7 @@ class TableCreator(object):
if not self.table_created:
self.column_inferer.force_analysis()
- self._do_create_table()
+ self._do_create_table(filename)
if total_data_lines_read == 0:
@@ -960,20 +981,20 @@ class TableCreator(object):
self.state = TableCreatorState.FULLY_READ
return
- def _flush_pre_creation_rows(self):
+ def _flush_pre_creation_rows(self, filename):
for i, col_vals in enumerate(self.pre_creation_rows):
if self.skip_header and i == 0:
# skip header line
continue
- self._insert_row(col_vals)
+ self._insert_row(filename, col_vals)
self._flush_inserts()
self.pre_creation_rows = []
- def _insert_row(self, col_vals):
+ def _insert_row(self, filename, col_vals):
# If table has not been created yet
if not self.table_created:
# Try to create it along with another "example" line of data
- self.try_to_create_table(col_vals)
+ self.try_to_create_table(filename, col_vals)
# If the table is still not created, then we don't have enough data, just
# store the data and return
@@ -1069,19 +1090,19 @@ class TableCreator(object):
# print self.db.execute_and_fetch(self.db.generate_end_transaction())
self.buffered_inserts = []
- def try_to_create_table(self, col_vals):
+ def try_to_create_table(self, filename, col_vals):
if self.table_created:
raise Exception('Table is already created')
# Add that line to the column inferer
- result = self.column_inferer.analyze(col_vals)
+ result = self.column_inferer.analyze(filename, col_vals)
# If inferer succeeded,
if result:
- self._do_create_table()
+ self._do_create_table(filename)
else:
pass # We don't have enough information for creating the table yet
- def _do_create_table(self):
+ def _do_create_table(self,filename):
# Then generate a temp table name
self.table_name = self.db.generate_temp_table_name()
# Get the column definition dict from the inferer
@@ -1101,7 +1122,7 @@ class TableCreator(object):
self.db.execute_and_fetch(create_table_stmt)
# Mark the table as created
self.table_created = True
- self._flush_pre_creation_rows()
+ self._flush_pre_creation_rows(filename)
def drop_table(self):
if self.table_created:
diff --git a/test/test-suite b/test/test-suite
index e17afcd..bc7fc37 100755
--- a/test/test-suite
+++ b/test/test-suite
@@ -93,6 +93,9 @@ sample_data_with_empty_string_no_header = six.b("\n").join(
sample_data_with_header = header_row + six.b("\n") + sample_data_no_header
sample_data_with_missing_header_names = six.b("name,value1\n") + sample_data_no_header
+def generate_sample_data_with_header(header):
+ return header + six.b("\n") + sample_data_no_header
+
sample_quoted_data = six.b('''non_quoted regular_double_quoted double_double_quoted escaped_double_quoted multiline_double_double_quoted multiline_escaped_double_quoted
control-value-1 "control-value-2" control-value-3 "control-value-4" control-value-5 "control-value-6"
non-quoted-value "this is a quoted value" "this is a ""double double"" quoted value" "this is an escaped \\"quoted value\\"" "this is a double double quoted ""multiline
@@ -1422,6 +1425,109 @@ class BasicTests(AbstractQTestCase):
self.cleanup(tmpfile)
+class MultiHeaderTests(AbstractQTestCase):
+ def test_output_header_when_multiple_input_headers_exist(self):
+ TMPFILE_COUNT = 5
+ tmpfiles = [self.create_file_with_data(sample_data_with_header) for x in range(TMPFILE_COUNT)]
+
+ tmpfilenames = "+".join(map(lambda x:x.name, tmpfiles))
+
+ cmd = '../bin/q -d , "select name,value1,value2 from %s order by name" -H -O' % tmpfilenames
+ retcode, o, e = run_command(cmd)
+
+ self.assertEqual(retcode, 0)
+ self.assertEqual(len(o), TMPFILE_COUNT*3+1)
+ self.assertEqual(o[0], six.b("name,value1,value2"))
+
+ for i in range (TMPFILE_COUNT):
+ self.assertEqual(o[1+i],sample_data_rows[0])
+ for i in range (TMPFILE_COUNT):
+ self.assertEqual(o[TMPFILE_COUNT+1+i],sample_data_rows[1])
+ for i in range (TMPFILE_COUNT):
+ self.assertEqual(o[TMPFILE_COUNT*2+1+i],sample_data_rows[2])
+
+ for oi in o[1:]:
+ self.assertTrue(six.b('name') not in oi)
+
+ for i in range(TMPFILE_COUNT):
+ self.cleanup(tmpfiles[i])
+
+ def test_output_header_when_extra_header_column_names_are_different(self):
+ tmpfile1 = self.create_file_with_data(sample_data_with_header)
+ tmpfile2 = self.create_file_with_data(generate_sample_data_with_header(six.b('othername,value1,value2')))
+
+ cmd = '../bin/q -d , "select name,value1,value2 from %s+%s order by name" -H -O' % (tmpfile1.name,tmpfile2.name)
+ retcode, o, e = run_command(cmd)
+
+ self.assertEqual(retcode, 35)
+ self.assertEqual(len(o), 0)
+ self.assertEqual(len(e), 1)
+ self.assertTrue(e[0].startswith(six.b("Bad header row:")))
+
+ self.cleanup(tmpfile1)
+ self.cleanup(tmpfile2)
+
+ def test_output_header_when_extra_header_has_different_number_of_columns(self):
+ tmpfile1 = self.create_file_with_data(sample_data_with_header)
+ tmpfile2 = self.create_file_with_data(generate_sample_data_with_header(six.b('name,value1')))
+
+ cmd = '../bin/q -d , "select name,value1,value2 from %s+%s order by name" -H -O' % (tmpfile1.name,tmpfile2.name)
+ retcode, o, e = run_command(cmd)
+
+ self.assertEqual(retcode, 35)
+ self.assertEqual(len(o), 0)
+ self.assertEqual(len(e), 1)
+ self.assertTrue(e[0].startswith(six.b("Bad header row:")))
+
+ self.cleanup(tmpfile1)
+ self.cleanup(tmpfile2)
+
+ def test_output_header_when_extra_header_has_different_number_of_columns2(self):
+ original_header = header_row
+ tmpfile1 = self.create_file_with_data(sample_data_with_header)
+ different_header = six.b('name,value1,value2,value3')
+ tmpfile2 = self.create_file_with_data(generate_sample_data_with_header(different_header))
+
+ SELECT_table_name = '%s+%s' % (tmpfile1.name,tmpfile2.name)
+ cmd = '../bin/q -d , "select name,value1,value2 from %s order by name" -H -O' % (SELECT_table_name)
+ retcode, o, e = run_command(cmd)
+
+ self.assertEqual(retcode, 35)
+ self.assertEqual(len(o), 0)
+ self.assertEqual(len(e), 1)
+ expected_message = six.b('Bad header row: Extra header %s in file %s mismatches original header %s from file %s. Table name is %s') % \
+ (different_header,six.b(tmpfile2.name),original_header,six.b(tmpfile1.name),six.b(SELECT_table_name))
+
+ self.assertEqual(e[0],expected_message)
+
+ self.cleanup(tmpfile1)
+ self.cleanup(tmpfile2)
+
+ # Not the best behavior: if the first line of an additional file happens to contain exactly the
+ # same content as the original header, then q will skip that line instead of failing.
+ # This is an extremely rare case, and for any table with numeric values it is not an issue, since
+ # column names cannot be numbers.
+ def test_output_header_when_additional_files_dont_have_a_header(self):
+ original_header = header_row
+ tmpfile1 = self.create_file_with_data(sample_data_with_header)
+ tmpfile2 = self.create_file_with_data(sample_data_no_header)
+
+ SELECT_table_name = '%s+%s' % (tmpfile1.name,tmpfile2.name)
+ cmd = '../bin/q -d , "select name,value1,value2 from %s order by name" -H -O' % (SELECT_table_name)
+ retcode, o, e = run_command(cmd)
+
+ self.assertEqual(retcode, 35)
+ self.assertEqual(len(o), 0)
+ self.assertEqual(len(e), 1)
+ expected_message = six.b('Bad header row: Extra header %s in file %s mismatches original header %s from file %s. Table name is %s') % \
+ (sample_data_rows[0],six.b(tmpfile2.name),original_header,six.b(tmpfile1.name),six.b(SELECT_table_name))
+
+ self.assertEqual(e[0],expected_message)
+
+ self.cleanup(tmpfile1)
+ self.cleanup(tmpfile2)
+
+
class ParsingModeTests(AbstractQTestCase):
def test_strict_mode_column_count_mismatch_error(self):
@@ -2351,7 +2457,8 @@ def suite():
formatting = tl.loadTestsFromTestCase(FormattingTests)
basic_module_stuff = tl.loadTestsFromTestCase(BasicModuleTests)
save_db_to_disk_tests = tl.loadTestsFromTestCase(SaveDbToDiskTests)
- return unittest.TestSuite([basic_module_stuff, basic_stuff, parsing_mode, sql, formatting,save_db_to_disk_tests])
+ multi_header_tests = tl.loadTestsFromTestCase(MultiHeaderTests)
+ return unittest.TestSuite([basic_module_stuff, basic_stuff, parsing_mode, sql, formatting,save_db_to_disk_tests,multi_header_tests])
if __name__ == '__main__':
if len(sys.argv) > 1: