// path: root/src/base64_decode.rs
// blob: de8147cd190c6db08069c2fed4ad48bff1c52324 (plain)
/// An error that can occur during base64 decoding.
#[derive(Debug, Clone, Eq, PartialEq)]
pub enum Error {
	/// The input contained a byte that is not a valid base64 character.
	///
	/// The payload is the offending byte.
	InvalidBase64Char(u8),
}

impl std::fmt::Display for Error {
	/// Write a human readable description of the error.
	///
	/// The offending byte is printed as a quoted character using `char`'s debug formatting,
	/// so control characters show up escaped (for example `'\n'`).
	fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
		match self {
			// Every `u8` converts losslessly to a `char`, so the infallible `char::from` is used
			// rather than a fallible conversion followed by `unwrap()`.
			Self::InvalidBase64Char(value) => write!(f, "Invalid base64 character: {:?}", char::from(*value)),
		}
	}
}

/// Allow the error to be used as a `dyn std::error::Error` (boxing, `?` into generic error types).
impl std::error::Error for Error {}

/// Decode base64 encoded data into raw bytes.
///
/// ASCII whitespace anywhere in the input is skipped.
/// A trailing run of `=` padding and whitespace is cut off before decoding,
/// so padding in the input is optional.
///
/// # Errors
/// Returns [`Error::InvalidBase64Char`] for the first remaining byte that is not a base64 character.
/// This includes a `=` that is still followed by data.
pub fn base64_decode(input: &[u8]) -> Result<Vec<u8>, Error> {
	// Find where the payload ends: one past the last byte that is neither padding nor whitespace.
	let end = match input.iter().rposition(|byte| !(*byte == b'=' || byte.is_ascii_whitespace())) {
		None => return Ok(Vec::new()),
		Some(last) => last + 1,
	};
	let payload = &input[..end];

	// Every group of (up to) four input characters yields at most three output bytes.
	let mut decoded = Vec::with_capacity((payload.len() + 3) / 4 * 3);
	let mut state = Base64Decoder::new();

	for symbol in payload.iter().copied().filter(|symbol| !symbol.is_ascii_whitespace()) {
		if let Some(value) = state.feed(symbol)? {
			decoded.push(value);
		}
	}

	Ok(decoded)
}

/// Get the 6 bit value for a base64 character.
fn base64_value(byte: u8) -> Result<u8, Error> {
	match byte {
		b'A'..=b'Z' => Ok(byte - b'A'),
		b'a'..=b'z' => Ok(byte - b'a' + 26),
		b'0'..=b'9' => Ok(byte - b'0' + 52),
		b'+' => Ok(62),
		b'/' => Ok(63),
		byte => Err(Error::InvalidBase64Char(byte)),
	}
}

/// Incremental base64 decoder that is fed one character at a time.
struct Base64Decoder {
	/// Bit accumulator.
	///
	/// Pending bits are kept at the most significant end of the value.
	buffer: u16,

	/// How many of the most significant bits of `buffer` hold pending data.
	valid_bits: u8,
}

impl Base64Decoder {
	/// Create a decoder with an empty buffer.
	fn new() -> Self {
		Self { buffer: 0, valid_bits: 0 }
	}

	/// Push one base64 character into the decoder.
	///
	/// Returns `Ok(Some(byte))` when this character completed an output byte.
	/// Returns `Ok(None)` when more input is needed before a byte is complete.
	/// Returns an error if `byte` is not a base64 character; the decoder state is not modified in that case.
	fn feed(&mut self, byte: u8) -> Result<Option<u8>, Error> {
		debug_assert!(self.valid_bits < 8);
		// Validate first, so an invalid character leaves the buffer untouched.
		let sextet = u16::from(base64_value(byte)?);
		// Place the 6 new bits directly below the bits that are already pending:
		// with nothing pending they land in the top 6 bits (16 - 6 = 10).
		let shift = 10 - self.valid_bits;
		self.buffer |= sextet << shift;
		self.valid_bits += 6;
		// Hand out the most significant byte if it is now complete.
		Ok(self.consume_buffer_front())
	}

	/// Remove and return the most significant byte of the buffer if it is fully populated.
	fn consume_buffer_front(&mut self) -> Option<u8> {
		if self.valid_bits < 8 {
			return None;
		}
		// Split into the completed high byte and the leftover low byte,
		// then move the leftover bits up to the most significant end.
		let [front, rest] = self.buffer.to_be_bytes();
		self.buffer = u16::from(rest) << 8;
		self.valid_bits -= 8;
		Some(front)
	}
}

/// Unit tests for the base64 decoder.
#[cfg(test)]
mod test {
	use super::*;
	use assert2::assert;

	/// Decoding works with no padding, partial padding and full padding.
	#[test]
	fn test_decode_base64() {
		assert!(let Ok(b"0") = base64_decode(b"MA").as_deref());
		assert!(let Ok(b"0") = base64_decode(b"MA=").as_deref());
		assert!(let Ok(b"0") = base64_decode(b"MA==").as_deref());
		assert!(let Ok(b"aap noot mies") = base64_decode(b"YWFwIG5vb3QgbWllcw").as_deref());
		assert!(let Ok(b"aap noot mies") = base64_decode(b"YWFwIG5vb3QgbWllcw=").as_deref());
		assert!(let Ok(b"aap noot mies") = base64_decode(b"YWFwIG5vb3QgbWllcw==").as_deref());
	}

	/// ASCII whitespace is skipped, both inside the data and after the padding.
	#[test]
	fn test_decode_base64_whitespace() {
		assert!(let Ok(b"0") = base64_decode(b" M A = = \n").as_deref());
		assert!(let Ok(b"aap noot mies") = base64_decode(b"YWFwIG5v\r\nb3QgbWllcw==\n").as_deref());
	}

	/// Input without any data characters decodes to an empty buffer.
	#[test]
	fn test_decode_base64_empty() {
		assert!(let Ok(b"") = base64_decode(b"").as_deref());
		assert!(let Ok(b"") = base64_decode(b"==").as_deref());
		assert!(let Ok(b"") = base64_decode(b" =\n").as_deref());
	}

	/// Bytes outside the base64 alphabet are reported, including padding that is followed by data.
	#[test]
	fn test_decode_base64_invalid() {
		assert!(let Err(Error::InvalidBase64Char(b'!')) = base64_decode(b"MA!A"));
		assert!(let Err(Error::InvalidBase64Char(b'=')) = base64_decode(b"M=A"));
	}
}