diff --git a/psoutils/src/compression.rs b/psoutils/src/compression.rs index 527a12e..bccfa46 100644 --- a/psoutils/src/compression.rs +++ b/psoutils/src/compression.rs @@ -1,5 +1,13 @@ use std::ffi::c_void; +use thiserror::Error; + +#[derive(Error, Debug)] +pub enum PrsCompressionError { + #[error("Error due to bad input data: {0}")] + BadData(String), +} + struct Context { bitpos: u8, forward_log: Vec, @@ -128,7 +136,7 @@ fn is_mem_equal(base: &[u8], offset1: isize, offset2: isize, length: usize) -> b } } -pub fn prs_compress(source: &[u8]) -> Box<[u8]> { +pub fn prs_compress(source: &[u8]) -> Result, PrsCompressionError> { let mut pc = Context::new(); let mut x: isize = 0; @@ -160,7 +168,15 @@ pub fn prs_compress(source: &[u8]) -> Box<[u8]> { } if lssize == 0 { - pc.raw_byte(source[x as usize]); + pc.raw_byte(match source.get(x as usize) { + Some(value) => *value, + None => { + return Err(PrsCompressionError::BadData(format!( + "tried to add raw byte from source at out-of-bounds index {}", + x + ))) + } + }); } else { pc.copy(lsoffset, lssize as u8); x += lssize - 1; @@ -169,7 +185,7 @@ pub fn prs_compress(source: &[u8]) -> Box<[u8]> { x += 1; } - pc.finish() + Ok(pc.finish()) } enum Next { @@ -198,7 +214,7 @@ impl<'a> ByteReader<'a> { } } -pub fn prs_decompress(source: &[u8]) -> Box<[u8]> { +pub fn prs_decompress(source: &[u8]) -> Result, PrsCompressionError> { let mut output = Vec::new(); let mut reader = ByteReader::new(source); let mut r3: i32; @@ -208,9 +224,19 @@ pub fn prs_decompress(source: &[u8]) -> Box<[u8]> { let mut flag: bool; let mut offset: i32; + // if you prs_compress a zero-length buffer, you get a 3-byte "compressed" result. + // therefore, 3 byte minimum input buffer is required to get any kind of "meaningful" + // decompression result back out + if source.len() < 3 { + return Err(PrsCompressionError::BadData(format!( + "Input data is too short: {} bytes", + source.len() + ))); + } + current_byte = match reader.next() { Next::Byte(byte) => byte, - Next::Eof() => return output.into_boxed_slice(), + Next::Eof() => return Ok(output.into_boxed_slice()), }; loop { @@ -218,7 +244,7 @@ pub fn prs_decompress(source: &[u8]) -> Box<[u8]> { if bitpos == 0 { current_byte = match reader.next() { Next::Byte(byte) => byte, - Next::Eof() => return output.into_boxed_slice(), + Next::Eof() => return Ok(output.into_boxed_slice()), }; bitpos = 8; } @@ -228,7 +254,7 @@ pub fn prs_decompress(source: &[u8]) -> Box<[u8]> { if flag { output.push(match reader.next() { Next::Byte(byte) => byte, - Next::Eof() => return output.into_boxed_slice(), + Next::Eof() => return Ok(output.into_boxed_slice()), }); continue; } @@ -237,7 +263,7 @@ pub fn prs_decompress(source: &[u8]) -> Box<[u8]> { if bitpos == 0 { current_byte = match reader.next() { Next::Byte(byte) => byte, - Next::Eof() => return output.into_boxed_slice(), + Next::Eof() => return Ok(output.into_boxed_slice()), }; bitpos = 8; } @@ -247,22 +273,22 @@ pub fn prs_decompress(source: &[u8]) -> Box<[u8]> { if flag { r3 = match reader.next() { Next::Byte(byte) => byte as i32, - Next::Eof() => return output.into_boxed_slice(), + Next::Eof() => return Ok(output.into_boxed_slice()), }; let high_byte = match reader.next() { Next::Byte(byte) => byte as i32, - Next::Eof() => return output.into_boxed_slice(), + Next::Eof() => return Ok(output.into_boxed_slice()), }; offset = ((high_byte & 0xff) << 8) | (r3 & 0xff); if offset == 0 { - return output.into_boxed_slice(); + return Ok(output.into_boxed_slice()); } r3 &= 0x00000007; r5 = (offset >> 3) | -8192i32; // 0xffffe000 if r3 == 0 { r3 = match reader.next() { Next::Byte(byte) => byte as i32, - Next::Eof() => return output.into_boxed_slice(), + Next::Eof() => return Ok(output.into_boxed_slice()), }; r3 = (r3 & 0xff) + 1; } else { @@ -275,7 +301,7 @@ pub fn prs_decompress(source: &[u8]) -> Box<[u8]> { if bitpos == 0 { current_byte = match reader.next() { Next::Byte(byte) => byte, - Next::Eof() => return output.into_boxed_slice(), + Next::Eof() => return Ok(output.into_boxed_slice()), }; bitpos = 8; } @@ -286,7 +312,7 @@ pub fn prs_decompress(source: &[u8]) -> Box<[u8]> { } offset = match reader.next() { Next::Byte(byte) => byte as i32, - Next::Eof() => return output.into_boxed_slice(), + Next::Eof() => return Ok(output.into_boxed_slice()), }; r3 += 2; r5 = offset | -256i32; // 0xffffff00 @@ -296,13 +322,25 @@ pub fn prs_decompress(source: &[u8]) -> Box<[u8]> { } for _ in 0..r3 { let index = output.len() as i32 + r5; - output.push(output[index as usize]); + output.push(match output.get(index as usize) { + Some(value) => *value, + None => { + return Err(PrsCompressionError::BadData(format!( + "tried to push copy of byte at out-of-bounds index {}", + index + ))) + } + }); } } } #[cfg(test)] mod tests { + use claim::*; + use rand::rngs::StdRng; + use rand::{Fill, SeedableRng}; + use super::*; struct TestData<'a> { @@ -652,13 +690,33 @@ I do not like green eggs and ham." ]; #[test] - pub fn compresses_things() { + pub fn compresses_things() -> Result<(), PrsCompressionError> { for (index, test) in TEST_DATA.iter().enumerate() { println!("\ntest #{}", index); println!(" prs_compress({:02x?})", test.uncompressed); - assert_eq!(*test.compressed, *prs_compress(&test.uncompressed)); + assert_eq!(*test.compressed, *prs_compress(&test.uncompressed)?); println!(" prs_decompress({:02x?})", test.compressed); - assert_eq!(*test.uncompressed, *prs_decompress(&test.compressed)); + assert_eq!(*test.uncompressed, *prs_decompress(&test.compressed)?); } + Ok(()) + } + + #[test] + pub fn decompress_bad_data_error_result() -> Result<(), PrsCompressionError> { + let data: &[u8] = &[]; + assert_matches!(prs_decompress(data), Err(PrsCompressionError::BadData(..))); + + let data: &[u8] = &[1, 2]; + assert_matches!(prs_decompress(data), Err(PrsCompressionError::BadData(..))); + + let data: &[u8] = &[1, 2, 3]; + assert_matches!(prs_decompress(data), Err(PrsCompressionError::BadData(..))); + + let mut data = [0u8; 1024]; + let mut rng = StdRng::seed_from_u64(42); + data.try_fill(&mut rng).unwrap(); + assert_matches!(prs_decompress(&data), Err(PrsCompressionError::BadData(..))); + + Ok(()) } } diff --git a/psoutils/src/quest/bin.rs b/psoutils/src/quest/bin.rs index 0424e6c..52d15d1 100644 --- a/psoutils/src/quest/bin.rs +++ b/psoutils/src/quest/bin.rs @@ -6,7 +6,7 @@ use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use thiserror::Error; use crate::bytes::*; -use crate::compression::{prs_compress, prs_decompress}; +use crate::compression::{prs_compress, prs_decompress, PrsCompressionError}; use crate::text::Language; pub const QUEST_BIN_NAME_LENGTH: usize = 32; @@ -23,6 +23,9 @@ pub enum QuestBinError { #[error("I/O error while processing quest bin")] IoError(#[from] std::io::Error), + #[error("PRS compression failed")] + PrsCompressionError(#[from] PrsCompressionError), + #[error("Bad quest bin data format: {0}")] DataFormatError(String), } @@ -75,7 +78,7 @@ pub struct QuestBin { impl QuestBin { pub fn from_compressed_bytes(bytes: &[u8]) -> Result { - let decompressed = prs_decompress(&bytes); + let decompressed = prs_decompress(&bytes)?; let mut reader = Cursor::new(decompressed); Ok(QuestBin::from_uncompressed_bytes(&mut reader)?) } @@ -283,7 +286,7 @@ impl QuestBin { pub fn to_compressed_bytes(&self) -> Result, QuestBinError> { let uncompressed = self.to_uncompressed_bytes()?; - Ok(prs_compress(uncompressed.as_ref())) + Ok(prs_compress(uncompressed.as_ref())?) } pub fn calculate_size(&self) -> usize { diff --git a/psoutils/src/quest/dat.rs b/psoutils/src/quest/dat.rs index a3233c8..ec583d5 100644 --- a/psoutils/src/quest/dat.rs +++ b/psoutils/src/quest/dat.rs @@ -6,7 +6,7 @@ use std::path::Path; use byteorder::{LittleEndian, ReadBytesExt, WriteBytesExt}; use thiserror::Error; -use crate::compression::{prs_compress, prs_decompress}; +use crate::compression::{prs_compress, prs_decompress, PrsCompressionError}; pub const QUEST_DAT_TABLE_HEADER_SIZE: usize = 16; @@ -58,6 +58,9 @@ pub enum QuestDatError { #[error("I/O error while processing quest dat")] IoError(#[from] std::io::Error), + #[error("PRS compression failed")] + PrsCompressionError(#[from] PrsCompressionError), + #[error("Bad quest dat data format: {0}")] DataFormatError(String), } @@ -163,7 +166,7 @@ pub struct QuestDat { impl QuestDat { pub fn from_compressed_bytes(bytes: &[u8]) -> Result { - let decompressed = prs_decompress(&bytes); + let decompressed = prs_decompress(&bytes)?; let mut reader = Cursor::new(decompressed); Ok(QuestDat::from_uncompressed_bytes(&mut reader)?) } @@ -276,7 +279,7 @@ impl QuestDat { pub fn to_compressed_bytes(&self) -> Result, QuestDatError> { let uncompressed = self.to_uncompressed_bytes()?; - Ok(prs_compress(uncompressed.as_ref())) + Ok(prs_compress(uncompressed.as_ref())?) } pub fn calculate_size(&self) -> usize {