diff options
-rw-r--r-- | buffered-reader/src/lib.rs | 22 | ||||
-rw-r--r-- | openpgp/src/parse/parse.rs | 71 |
2 files changed, 81 insertions, 12 deletions
diff --git a/buffered-reader/src/lib.rs b/buffered-reader/src/lib.rs index 951b8c71..7f082f72 100644 --- a/buffered-reader/src/lib.rs +++ b/buffered-reader/src/lib.rs @@ -152,6 +152,24 @@ pub trait BufferedReader : io::Read + fmt::Debug { return Ok(data); } + /// Like steal_eof, but instead of returning the data, the data is + /// discarded. + fn drop_eof(&mut self) -> Result<(), std::io::Error> { + loop { + match self.data_consume(DEFAULT_BUF_SIZE) { + Ok(ref buffer) => + if buffer.len() < DEFAULT_BUF_SIZE { + // EOF. + break; + }, + Err(err) => + return Err(err), + } + } + + Ok(()) + } + fn into_inner<'a>(self: Box<Self>) -> Option<Box<BufferedReader + 'a>> where Self: 'a; } @@ -234,6 +252,10 @@ impl <'a> BufferedReader for Box<BufferedReader + 'a> { return self.as_mut().steal_eof(); } + fn drop_eof(&mut self) -> Result<(), std::io::Error> { + return self.as_mut().drop_eof(); + } + fn into_inner<'b>(self: Box<Self>) -> Option<Box<BufferedReader + 'b>> where Self: 'b { // Strip the outer box. diff --git a/openpgp/src/parse/parse.rs b/openpgp/src/parse/parse.rs index 103f06d7..381ba8f6 100644 --- a/openpgp/src/parse/parse.rs +++ b/openpgp/src/parse/parse.rs @@ -380,13 +380,14 @@ fn literal_parser_test () { assert_eq!(header.length, BodyLength::Full(18)); let mut pp = literal_parser(bio).unwrap(); + let content = pp.steal_eof().unwrap(); let p = pp.finish(); // eprintln!("{:?}", p); if let &Packet::Literal(ref p) = p { assert_eq!(p.format, 'b' as u8); assert_eq!(p.filename.as_ref().unwrap()[..], b"foobar"[..]); assert_eq!(p.date, 1507458744); - assert_eq!(p.common.content, Some(b"FOOBAR"[..].to_vec())); + assert_eq!(content, b"FOOBAR"); } else { unreachable!(); } @@ -404,6 +405,7 @@ fn literal_parser_test () { let bio2 = BufferedReaderPartialBodyFilter::new(bio, l); let mut pp = literal_parser(bio2).unwrap(); + let content = pp.steal_eof().unwrap(); let p = pp.finish(); if let &Packet::Literal(ref p) = p { assert_eq!(p.format, 't' as u8); @@ -413,7 +415,7 @@ fn literal_parser_test () { let expected = bytes!("a-cypherpunks-manifesto.txt"); - assert_eq!(p.common.content, Some(expected.to_vec())); + assert_eq!(&content[..], &expected[..]); } else { unreachable!(); } @@ -518,19 +520,20 @@ fn compressed_data_parser_test () { } // ppo should be the literal data packet. - assert!(ppo.is_some()); + let mut pp = ppo.unwrap(); // It is a child. assert_eq!(relative_position, 1); - let (literal, ppo, _relative_position) - = ppo.unwrap().recurse().unwrap(); + let content = pp.steal_eof().unwrap(); + + let (literal, ppo, _relative_position) = pp.recurse().unwrap(); if let Packet::Literal(literal) = literal { assert_eq!(literal.filename, None); assert_eq!(literal.format, 'b' as u8); assert_eq!(literal.date, 1509219866); - assert_eq!(literal.common.content, Some(expected.to_vec())); + assert_eq!(content, expected.to_vec()); } else { unreachable!(); } @@ -549,6 +552,10 @@ struct PacketParserBuilderSettings { // 255. Moreover, if it is too large, then a read from the // pipeline will blow the stack. max_recursion_depth: u8, + + // Whether a packet's contents should be buffered or dropped when + // the next packet is retrieved. + buffer_unread_content: bool, } pub struct PacketParserBuilder<R: BufferedReader> { @@ -560,6 +567,7 @@ pub struct PacketParserBuilder<R: BufferedReader> { const PACKET_PARSER_DEFAULTS : PacketParserBuilderSettings = PacketParserBuilderSettings { max_recursion_depth: MAX_RECURSION_DEPTH, + buffer_unread_content: false, }; impl<R: BufferedReader> PacketParserBuilder<R> { @@ -577,6 +585,18 @@ impl<R: BufferedReader> PacketParserBuilder<R> { self } + pub fn buffer_unread_content(mut self) + -> PacketParserBuilder<R> { + self.settings.buffer_unread_content = true; + self + } + + pub fn drop_unread_content(mut self) + -> PacketParserBuilder<R> { + self.settings.buffer_unread_content = false; + self + } + pub fn finalize<'a>(self) -> Result<Option<PacketParser<'a>>, std::io::Error> where Self: 'a { // Parse the first packet. @@ -641,8 +661,6 @@ pub struct PacketParser<'a> { // run-time!). reader: Box<BufferedReader + 'a>, - settings: PacketParserBuilderSettings, - // Whether the caller read the packets content. If so, then we // can't recurse, because we're missing some of the packet! content_was_read: bool, @@ -653,6 +671,8 @@ pub struct PacketParser<'a> { // The packet that is being parsed. pub packet: Packet, + + settings: PacketParserBuilderSettings, } impl <'a> std::fmt::Debug for PacketParser<'a> { @@ -660,6 +680,7 @@ impl <'a> std::fmt::Debug for PacketParser<'a> { f.debug_struct("PacketParser") .field("reader", &self.reader) .field("packet", &self.packet) + .field("content_was_read", &self.content_was_read) .field("recursion_depth", &self.recursion_depth) .field("settings", &self.settings) .finish() @@ -874,7 +895,8 @@ impl <'a> PacketParser<'a> { self.next() } - pub fn finish<'b>(&'b mut self) -> &'b Packet { + pub fn buffer_unread_content<'b>(&'b mut self) + -> Result<&'b [u8], io::Error> { let mut rest = self.reader.steal_eof().unwrap(); if rest.len() > 0 { if let Some(mut content) = self.packet.content.take() { @@ -883,6 +905,22 @@ impl <'a> PacketParser<'a> { } else { self.packet.content = Some(rest); } + + Ok(&self.packet.content.as_ref().unwrap()[..]) + } else { + Ok(&b""[..]) + } + } + + pub fn finish<'b>(&'b mut self) -> &'b Packet { + if self.settings.buffer_unread_content { + if let Err(_err) = self.buffer_unread_content() { + // XXX: We should propagate the error. + unimplemented!(); + } + } else { + + self.reader.drop_eof().unwrap(); } return &mut self.packet; @@ -957,6 +995,11 @@ impl<'a> BufferedReader for PacketParser<'a> { return self.reader.steal_eof(); } + fn drop_eof(&mut self) -> Result<(), io::Error> { + self.content_was_read = true; + return self.reader.drop_eof(); + } + fn into_inner<'b>(self: Box<Self>) -> Option<Box<BufferedReader + 'b>> where Self: 'b { None @@ -1099,7 +1142,9 @@ impl Message { pub fn deserialize<R: BufferedReader>(bio: R) -> Result<Message, std::io::Error> { - PacketParserBuilder::from_buffered_reader(bio)?.deserialize() + PacketParserBuilder::from_buffered_reader(bio)? + .buffer_unread_content() + .deserialize() } pub fn from_reader<R: io::Read>(reader: R) @@ -1243,8 +1288,10 @@ mod message_test { // literal packet. When we read some of the compressed // packet, we expect recurse() to not recurse. - let ppo = PacketParser::from_file( - path_to("compressed-data-algo-1.gpg")).unwrap(); + let ppo = PacketParserBuilder::from_file( + path_to("compressed-data-algo-1.gpg")).unwrap() + .buffer_unread_content() + .finalize().unwrap(); let mut pp = ppo.unwrap(); if let Packet::CompressedData(_) = pp.packet { |