diff --git a/chapter1.md b/chapter1.md index e1d7c7c..2980f48 100644 --- a/chapter1.md +++ b/chapter1.md @@ -323,43 +323,43 @@ we'll use a `struct` called `BytePacketBuffer`. ```rust pub struct BytePacketBuffer { pub buf: [u8; 512], - pub pos: usize + pub pos: usize, } impl BytePacketBuffer { - // This gives us a fresh buffer for holding the packet contents, and a field for - // keeping track of where we are. + /// This gives us a fresh buffer for holding the packet contents, and a + /// field for keeping track of where we are. pub fn new() -> BytePacketBuffer { BytePacketBuffer { buf: [0; 512], - pos: 0 + pos: 0, } } - // When handling the reading of domain names, we'll need a way of - // reading and manipulating our buffer position. - + /// Current position within buffer fn pos(&self) -> usize { self.pos } + /// Step the buffer position forward a specific number of steps fn step(&mut self, steps: usize) -> Result<()> { self.pos += steps; Ok(()) } + /// Change the buffer position fn seek(&mut self, pos: usize) -> Result<()> { self.pos = pos; Ok(()) } - // A method for reading a single byte, and moving one step forward + /// Read a single byte and move the position one step forward fn read(&mut self) -> Result { if self.pos >= 512 { - return Err(Error::new(ErrorKind::InvalidInput, "End of buffer")); + return Err("End of buffer".into()); } let res = self.buf[self.pos]; self.pos += 1; @@ -367,49 +367,46 @@ impl BytePacketBuffer { Ok(res) } - // Methods for fetching data at a specified position, without modifying - // the internal position - + /// Get a single byte, without changing the buffer position fn get(&mut self, pos: usize) -> Result { if pos >= 512 { - return Err(Error::new(ErrorKind::InvalidInput, "End of buffer")); + return Err("End of buffer".into()); } Ok(self.buf[pos]) } + /// Get a range of bytes fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> { if start + len >= 512 { - return Err(Error::new(ErrorKind::InvalidInput, "End of buffer")); + return Err("End of buffer".into()); } - Ok(&self.buf[start..start+len as usize]) + Ok(&self.buf[start..start + len as usize]) } - // Methods for reading a u16 and u32 from the buffer, while stepping - // forward 2 or 4 bytes - - fn read_u16(&mut self) -> Result - { - let res = ((try!(self.read()) as u16) << 8) | - (try!(self.read()) as u16); + /// Read two bytes, stepping two steps forward + fn read_u16(&mut self) -> Result { + let res = ((self.read()? as u16) << 8) | (self.read()? as u16); Ok(res) } - fn read_u32(&mut self) -> Result - { - let res = ((try!(self.read()) as u32) << 24) | - ((try!(self.read()) as u32) << 16) | - ((try!(self.read()) as u32) << 8) | - ((try!(self.read()) as u32) << 0); + /// Read four bytes, stepping four steps forward + fn read_u32(&mut self) -> Result { + let res = ((self.read()? as u32) << 24) + | ((self.read()? as u32) << 16) + | ((self.read()? as u32) << 8) + | ((self.read()? as u32) << 0); Ok(res) } - // The tricky part: Reading domain names, taking labels into consideration. - // Will take something like [3]www[6]google[3]com[0] and append - // www.google.com to outstr. - fn read_qname(&mut self, outstr: &mut String) -> Result<()> - { + + /// Read a qname + /// + /// The tricky part: Reading domain names, taking labels into consideration. + /// Will take something like [3]www[6]google[3]com[0] and append + /// www.google.com to outstr. + fn read_qname(&mut self, outstr: &mut String) -> Result<()> { // Since we might encounter jumps, we'll keep track of our position // locally as opposed to using the position within the struct. This // allows us to move the shared position to a point past our current @@ -419,43 +416,54 @@ impl BytePacketBuffer { // track whether or not we've jumped let mut jumped = false; + let max_jumps = 5; + let mut jumps_performed = 0; - // Our delimiter which we append for each label. Since we don't want a dot at the - // beginning of the domain name we'll leave it empty for now and set it to "." at - // the end of the first iteration. + // Our delimiter which we append for each label. Since we don't want a + // dot at the beginning of the domain name we'll leave it empty for now + // and set it to "." at the end of the first iteration. let mut delim = ""; loop { + // Dns Packets are untrusted data, so we need to be paranoid. Someone + // can craft a packet with a cycle in the jump instructions. This guards + // against such packets. + if jumps_performed > max_jumps { + return Err(format!("Limit of {} jumps exceeded", max_jumps).into()); + } + // At this point, we're always at the beginning of a label. Recall // that labels start with a length byte. - let len = try!(self.get(pos)); + let len = self.get(pos)?; - // If len has the two most significant bit are set, it represents a jump to - // some other offset in the packet: + // If len has the two most significant bit are set, it represents a + // jump to some other offset in the packet: if (len & 0xC0) == 0xC0 { // Update the buffer position to a point past the current // label. We don't need to touch it any further. if !jumped { - try!(self.seek(pos+2)); + self.seek(pos + 2)?; } // Read another byte, calculate offset and perform the jump by // updating our local position variable - let b2 = try!(self.get(pos+1)) as u16; + let b2 = self.get(pos + 1)? as u16; let offset = (((len as u16) ^ 0xC0) << 8) | b2; pos = offset as usize; // Indicate that a jump was performed. jumped = true; - } + jumps_performed += 1; + continue; + } // The base scenario, where we're reading a single label and // appending it to the output: else { // Move a single byte forward to move past the length byte. pos += 1; - // Domain names are terminated by an empty label of length 0, so if the length is zero - // we're done. + // Domain names are terminated by an empty label of length 0, + // so if the length is zero we're done. if len == 0 { break; } @@ -463,9 +471,9 @@ impl BytePacketBuffer { // Append the delimiter to our output buffer first. outstr.push_str(delim); - // Extract the actual ASCII bytes for this label and append them to the output buffer. - - let str_buffer = try!(self.get_range(pos, len as usize)); + // Extract the actual ASCII bytes for this label and append them + // to the output buffer. + let str_buffer = self.get_range(pos, len as usize)?; outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase()); delim = "."; @@ -475,16 +483,13 @@ impl BytePacketBuffer { } } - // If a jump has been performed, we've already modified the buffer position state and - // shouldn't do so again. if !jumped { - try!(self.seek(pos)); + self.seek(pos)?; } Ok(()) - } // End of read_qname - -} // End of BytePacketBuffer + } +} ``` ### ResultCode @@ -492,14 +497,14 @@ impl BytePacketBuffer { Before we move on to the header, we'll add an enum for the values of `rescode` field: ```rust -#[derive(Copy,Clone,Debug,PartialEq,Eq)] +#[derive(Copy, Clone, Debug, PartialEq, Eq)] pub enum ResultCode { NOERROR = 0, FORMERR = 1, SERVFAIL = 2, NXDOMAIN = 3, NOTIMP = 4, - REFUSED = 5 + REFUSED = 5, } impl ResultCode { @@ -510,7 +515,7 @@ impl ResultCode { 3 => ResultCode::NXDOMAIN, 4 => ResultCode::NOTIMP, 5 => ResultCode::REFUSED, - 0 | _ => ResultCode::NOERROR + 0 | _ => ResultCode::NOERROR, } } } @@ -521,26 +526,26 @@ impl ResultCode { Now we can get to work on the header. We'll represent it like this: ```rust -#[derive(Clone,Debug)] +#[derive(Clone, Debug)] pub struct DnsHeader { pub id: u16, // 16 bits - pub recursion_desired: bool, // 1 bit - pub truncated_message: bool, // 1 bit + pub recursion_desired: bool, // 1 bit + pub truncated_message: bool, // 1 bit pub authoritative_answer: bool, // 1 bit - pub opcode: u8, // 4 bits - pub response: bool, // 1 bit + pub opcode: u8, // 4 bits + pub response: bool, // 1 bit - pub rescode: ResultCode, // 4 bits - pub checking_disabled: bool, // 1 bit - pub authed_data: bool, // 1 bit - pub z: bool, // 1 bit + pub rescode: ResultCode, // 4 bits + pub checking_disabled: bool, // 1 bit + pub authed_data: bool, // 1 bit + pub z: bool, // 1 bit pub recursion_available: bool, // 1 bit - pub questions: u16, // 16 bits - pub answers: u16, // 16 bits + pub questions: u16, // 16 bits + pub answers: u16, // 16 bits pub authoritative_entries: u16, // 16 bits - pub resource_entries: u16 // 16 bits + pub resource_entries: u16, // 16 bits } ``` @@ -549,30 +554,32 @@ The implementation involves a lot of bit twiddling: ```rust impl DnsHeader { pub fn new() -> DnsHeader { - DnsHeader { id: 0, + DnsHeader { + id: 0, - recursion_desired: false, - truncated_message: false, - authoritative_answer: false, - opcode: 0, - response: false, + recursion_desired: false, + truncated_message: false, + authoritative_answer: false, + opcode: 0, + response: false, - rescode: ResultCode::NOERROR, - checking_disabled: false, - authed_data: false, - z: false, - recursion_available: false, + rescode: ResultCode::NOERROR, + checking_disabled: false, + authed_data: false, + z: false, + recursion_available: false, - questions: 0, - answers: 0, - authoritative_entries: 0, - resource_entries: 0 } + questions: 0, + answers: 0, + authoritative_entries: 0, + resource_entries: 0, + } } pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<()> { - self.id = try!(buffer.read_u16()); + self.id = buffer.read_u16()?; - let flags = try!(buffer.read_u16()); + let flags = buffer.read_u16()?; let a = (flags >> 8) as u8; let b = (flags & 0xFF) as u8; self.recursion_desired = (a & (1 << 0)) > 0; @@ -587,10 +594,10 @@ impl DnsHeader { self.z = (b & (1 << 6)) > 0; self.recursion_available = (b & (1 << 7)) > 0; - self.questions = try!(buffer.read_u16()); - self.answers = try!(buffer.read_u16()); - self.authoritative_entries = try!(buffer.read_u16()); - self.resource_entries = try!(buffer.read_u16()); + self.questions = buffer.read_u16()?; + self.answers = buffer.read_u16()?; + self.authoritative_entries = buffer.read_u16()?; + self.resource_entries = buffer.read_u16()?; // Return the constant header size Ok(()) @@ -604,7 +611,7 @@ Before moving on to the question part of the packet, we'll need a way to represent the record type being queried: ```rust -#[derive(PartialEq,Eq,Debug,Clone,Hash,Copy)] +#[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)] pub enum QueryType { UNKNOWN(u16), A, // 1 @@ -621,7 +628,7 @@ impl QueryType { pub fn from_num(num: u16) -> QueryType { match num { 1 => QueryType::A, - _ => QueryType::UNKNOWN(num) + _ => QueryType::UNKNOWN(num), } } } @@ -633,24 +640,24 @@ The enum allows us to easily add more record types later on. Now for the question entries: ```rust -#[derive(Debug,Clone,PartialEq,Eq)] +#[derive(Debug, Clone, PartialEq, Eq)] pub struct DnsQuestion { pub name: String, - pub qtype: QueryType + pub qtype: QueryType, } impl DnsQuestion { pub fn new(name: String, qtype: QueryType) -> DnsQuestion { DnsQuestion { name: name, - qtype: qtype + qtype: qtype, } } pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<()> { - try!(buffer.read_qname(&mut self.name)); - self.qtype = QueryType::from_num(try!(buffer.read_u16())); // qtype - let _ = try!(buffer.read_u16()); // class + buffer.read_qname(&mut self.name)?; + self.qtype = QueryType::from_num(buffer.read_u16()?); // qtype + let _ = buffer.read_u16()?; // class Ok(()) } @@ -666,19 +673,19 @@ We'll obviously need a way of representing the actual dns records as well, and again we'll use an enum for easy expansion: ```rust -#[derive(Debug,Clone,PartialEq,Eq,Hash,PartialOrd,Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] #[allow(dead_code)] pub enum DnsRecord { UNKNOWN { domain: String, qtype: u16, data_len: u16, - ttl: u32 + ttl: u32, }, // 0 A { domain: String, addr: Ipv4Addr, - ttl: u32 + ttl: u32, }, // 1 } ``` @@ -690,39 +697,40 @@ this: ```rust impl DnsRecord { - pub fn read(buffer: &mut BytePacketBuffer) -> Result { let mut domain = String::new(); - try!(buffer.read_qname(&mut domain)); + buffer.read_qname(&mut domain)?; - let qtype_num = try!(buffer.read_u16()); + let qtype_num = buffer.read_u16()?; let qtype = QueryType::from_num(qtype_num); - let _ = try!(buffer.read_u16()); // class, which we ignore - let ttl = try!(buffer.read_u32()); - let data_len = try!(buffer.read_u16()); + let _ = buffer.read_u16()?; + let ttl = buffer.read_u32()?; + let data_len = buffer.read_u16()?; match qtype { - QueryType::A => { - let raw_addr = try!(buffer.read_u32()); - let addr = Ipv4Addr::new(((raw_addr >> 24) & 0xFF) as u8, - ((raw_addr >> 16) & 0xFF) as u8, - ((raw_addr >> 8) & 0xFF) as u8, - ((raw_addr >> 0) & 0xFF) as u8); + QueryType::A => { + let raw_addr = buffer.read_u32()?; + let addr = Ipv4Addr::new( + ((raw_addr >> 24) & 0xFF) as u8, + ((raw_addr >> 16) & 0xFF) as u8, + ((raw_addr >> 8) & 0xFF) as u8, + ((raw_addr >> 0) & 0xFF) as u8, + ); Ok(DnsRecord::A { domain: domain, addr: addr, - ttl: ttl + ttl: ttl, }) - }, + } QueryType::UNKNOWN(_) => { - try!(buffer.step(data_len as usize)); + buffer.step(data_len as usize)?; Ok(DnsRecord::UNKNOWN { domain: domain, qtype: qtype_num, data_len: data_len, - ttl: ttl + ttl: ttl, }) } } @@ -741,7 +749,7 @@ pub struct DnsPacket { pub questions: Vec, pub answers: Vec, pub authorities: Vec, - pub resources: Vec + pub resources: Vec, } impl DnsPacket { @@ -751,31 +759,30 @@ impl DnsPacket { questions: Vec::new(), answers: Vec::new(), authorities: Vec::new(), - resources: Vec::new() + resources: Vec::new(), } } pub fn from_buffer(buffer: &mut BytePacketBuffer) -> Result { let mut result = DnsPacket::new(); - try!(result.header.read(buffer)); + result.header.read(buffer)?; for _ in 0..result.header.questions { - let mut question = DnsQuestion::new("".to_string(), - QueryType::UNKNOWN(0)); - try!(question.read(buffer)); + let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0)); + question.read(buffer)?; result.questions.push(question); } for _ in 0..result.header.answers { - let rec = try!(DnsRecord::read(buffer)); + let rec = DnsRecord::read(buffer)?; result.answers.push(rec); } for _ in 0..result.header.authoritative_entries { - let rec = try!(DnsRecord::read(buffer)); + let rec = DnsRecord::read(buffer)?; result.authorities.push(rec); } for _ in 0..result.header.resource_entries { - let rec = try!(DnsRecord::read(buffer)); + let rec = DnsRecord::read(buffer)?; result.resources.push(rec); } @@ -789,26 +796,28 @@ impl DnsPacket { Let's use the `response_packet.txt` we generated earlier to try it out! ```rust -fn main() { - let mut f = File::open("response_packet.txt").unwrap(); +fn main() -> Result<()> { + let mut f = File::open("response_packet.txt")?; let mut buffer = BytePacketBuffer::new(); - f.read(&mut buffer.buf).unwrap(); + f.read(&mut buffer.buf)?; - let packet = DnsPacket::from_buffer(&mut buffer).unwrap(); - println!("{:?}", packet.header); + let packet = DnsPacket::from_buffer(&mut buffer)?; + println!("{:#?}", packet.header); for q in packet.questions { - println!("{:?}", q); + println!("{:#?}", q); } for rec in packet.answers { - println!("{:?}", rec); + println!("{:#?}", rec); } for rec in packet.authorities { - println!("{:?}", rec); + println!("{:#?}", rec); } for rec in packet.resources { - println!("{:?}", rec); + println!("{:#?}", rec); } + + Ok(()) } ``` diff --git a/chapter2.md b/chapter2.md index 277e9a3..0a4588e 100644 --- a/chapter2.md +++ b/chapter2.md @@ -21,7 +21,7 @@ impl BytePacketBuffer { fn write(&mut self, val: u8) -> Result<()> { if self.pos >= 512 { - return Err(Error::new(ErrorKind::InvalidInput, "End of buffer")); + return Err("End of buffer".into()); } self.buf[self.pos] = val; self.pos += 1; @@ -29,23 +29,23 @@ impl BytePacketBuffer { } fn write_u8(&mut self, val: u8) -> Result<()> { - try!(self.write(val)); + self.write(val)?; Ok(()) } fn write_u16(&mut self, val: u16) -> Result<()> { - try!(self.write((val >> 8) as u8)); - try!(self.write((val & 0xFF) as u8)); + self.write((val >> 8) as u8)?; + self.write((val & 0xFF) as u8)?; Ok(()) } fn write_u32(&mut self, val: u32) -> Result<()> { - try!(self.write(((val >> 24) & 0xFF) as u8)); - try!(self.write(((val >> 16) & 0xFF) as u8)); - try!(self.write(((val >> 8) & 0xFF) as u8)); - try!(self.write(((val >> 0) & 0xFF) as u8)); + self.write(((val >> 24) & 0xFF) as u8)?; + self.write(((val >> 16) & 0xFF) as u8)?; + self.write(((val >> 8) & 0xFF) as u8)?; + self.write(((val >> 0) & 0xFF) as u8)?; Ok(()) } @@ -55,22 +55,19 @@ We'll also need a function for writing query names in labeled form: ```rust fn write_qname(&mut self, qname: &str) -> Result<()> { - - let split_str = qname.split('.').collect::>(); - - for label in split_str { + for label in qname.split('.') { let len = label.len(); if len > 0x34 { - return Err(Error::new(ErrorKind::InvalidInput, "Single label exceeds 63 characters of length")); + return Err("Single label exceeds 63 characters of length".into()); } - try!(self.write_u8(len as u8)); + self.write_u8(len as u8)?; for b in label.as_bytes() { - try!(self.write_u8(*b)); + self.write_u8(*b)?; } } - try!(self.write_u8(0)); + self.write_u8(0)?; Ok(()) } @@ -89,24 +86,28 @@ impl DnsHeader { - snip - pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> { - try!(buffer.write_u16(self.id)); + buffer.write_u16(self.id)?; - try!(buffer.write_u8( ((self.recursion_desired as u8)) | - ((self.truncated_message as u8) << 1) | - ((self.authoritative_answer as u8) << 2) | - (self.opcode << 3) | - ((self.response as u8) << 7) as u8) ); + buffer.write_u8( + (self.recursion_desired as u8) + | ((self.truncated_message as u8) << 1) + | ((self.authoritative_answer as u8) << 2) + | (self.opcode << 3) + | ((self.response as u8) << 7) as u8, + )?; - try!(buffer.write_u8( (self.rescode.clone() as u8) | - ((self.checking_disabled as u8) << 4) | - ((self.authed_data as u8) << 5) | - ((self.z as u8) << 6) | - ((self.recursion_available as u8) << 7) )); + buffer.write_u8( + (self.rescode.clone() as u8) + | ((self.checking_disabled as u8) << 4) + | ((self.authed_data as u8) << 5) + | ((self.z as u8) << 6) + | ((self.recursion_available as u8) << 7), + )?; - try!(buffer.write_u16(self.questions)); - try!(buffer.write_u16(self.answers)); - try!(buffer.write_u16(self.authoritative_entries)); - try!(buffer.write_u16(self.resource_entries)); + buffer.write_u16(self.questions)?; + buffer.write_u16(self.answers)?; + buffer.write_u16(self.authoritative_entries)?; + buffer.write_u16(self.resource_entries)?; Ok(()) } @@ -124,12 +125,11 @@ impl DnsQuestion { - snip - pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> { - - try!(buffer.write_qname(&self.name)); + buffer.write_qname(&self.name)?; let typenum = self.qtype.to_num(); - try!(buffer.write_u16(typenum)); - try!(buffer.write_u16(1)); + buffer.write_u16(typenum)?; + buffer.write_u16(1)?; Ok(()) } @@ -148,23 +148,26 @@ impl DnsRecord { - snip - pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result { - let start_pos = buffer.pos(); match *self { - DnsRecord::A { ref domain, ref addr, ttl } => { - try!(buffer.write_qname(domain)); - try!(buffer.write_u16(QueryType::A.to_num())); - try!(buffer.write_u16(1)); - try!(buffer.write_u32(ttl)); - try!(buffer.write_u16(4)); + DnsRecord::A { + ref domain, + ref addr, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::A.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + buffer.write_u16(4)?; let octets = addr.octets(); - try!(buffer.write_u8(octets[0])); - try!(buffer.write_u8(octets[1])); - try!(buffer.write_u8(octets[2])); - try!(buffer.write_u8(octets[3])); - }, + buffer.write_u8(octets[0])?; + buffer.write_u8(octets[1])?; + buffer.write_u8(octets[2])?; + buffer.write_u8(octets[3])?; + } DnsRecord::UNKNOWN { .. } => { println!("Skipping record: {:?}", self); } @@ -185,26 +188,25 @@ impl DnsPacket { - snip - - pub fn write(&mut self, buffer: &mut BytePacketBuffer) -> Result<()> - { + pub fn write(&mut self, buffer: &mut BytePacketBuffer) -> Result<()> { self.header.questions = self.questions.len() as u16; self.header.answers = self.answers.len() as u16; self.header.authoritative_entries = self.authorities.len() as u16; self.header.resource_entries = self.resources.len() as u16; - try!(self.header.write(buffer)); + self.header.write(buffer)?; for question in &self.questions { - try!(question.write(buffer)); + question.write(buffer)?; } for rec in &self.answers { - try!(rec.write(buffer)); + rec.write(buffer)?; } for rec in &self.authorities { - try!(rec.write(buffer)); + rec.write(buffer)?; } for rec in &self.resources { - try!(rec.write(buffer)); + rec.write(buffer)?; } Ok(()) @@ -219,7 +221,7 @@ We're ready to implement our stub resolver. Rust includes a convenient `UDPSocket` which does most of the work. ```rust -fn main() { +fn main() -> Result<()> { // Perform an A query for google.com let qname = "google.com"; let qtype = QueryType::A; @@ -228,7 +230,7 @@ fn main() { let server = ("8.8.8.8", 53); // Bind a UDP socket to an arbitrary port - let socket = UdpSocket::bind(("0.0.0.0", 43210)).unwrap(); + let socket = UdpSocket::bind(("0.0.0.0", 43210))?; // Build our query packet. It's important that we remember to set the // `recursion_desired` flag. As noted earlier, the packet id is arbitrary. @@ -237,37 +239,41 @@ fn main() { packet.header.id = 6666; packet.header.questions = 1; packet.header.recursion_desired = true; - packet.questions.push(DnsQuestion::new(qname.to_string(), qtype)); + packet + .questions + .push(DnsQuestion::new(qname.to_string(), qtype)); // Use our new write method to write the packet to a buffer... let mut req_buffer = BytePacketBuffer::new(); - packet.write(&mut req_buffer).unwrap(); + packet.write(&mut req_buffer)?; // ...and send it off to the server using our socket: - socket.send_to(&req_buffer.buf[0..req_buffer.pos], server).unwrap(); + socket.send_to(&req_buffer.buf[0..req_buffer.pos], server)?; // To prepare for receiving the response, we'll create a new `BytePacketBuffer`, // and ask the socket to write the response directly into our buffer. let mut res_buffer = BytePacketBuffer::new(); - socket.recv_from(&mut res_buffer.buf).unwrap(); + socket.recv_from(&mut res_buffer.buf)?; // As per the previous section, `DnsPacket::from_buffer()` is then used to // actually parse the packet after which we can print the response. - let res_packet = DnsPacket::from_buffer(&mut res_buffer).unwrap(); - println!("{:?}", res_packet.header); + let res_packet = DnsPacket::from_buffer(&mut res_buffer)?; + println!("{:#?}", res_packet.header); for q in res_packet.questions { - println!("{:?}", q); + println!("{:#?}", q); } for rec in res_packet.answers { - println!("{:?}", rec); + println!("{:#?}", rec); } for rec in res_packet.authorities { - println!("{:?}", rec); + println!("{:#?}", rec); } for rec in res_packet.resources { - println!("{:?}", rec); + println!("{:#?}", rec); } + + Ok(()) } ``` diff --git a/chapter3.md b/chapter3.md index c9e0275..3021a99 100644 --- a/chapter3.md +++ b/chapter3.md @@ -68,14 +68,14 @@ Let's go ahead and add them to our code! First we'll update our `QueryType` enum: ```rust -#[derive(PartialEq,Eq,Debug,Clone,Hash,Copy)] +#[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)] pub enum QueryType { UNKNOWN(u16), - A, // 1 - NS, // 2 + A, // 1 + NS, // 2 CNAME, // 5 - MX, // 15 - AAAA, // 28 + MX, // 15 + AAAA, // 28 } ``` @@ -101,7 +101,7 @@ impl QueryType { 5 => QueryType::CNAME, 15 => QueryType::MX, 28 => QueryType::AAAA, - _ => QueryType::UNKNOWN(num) + _ => QueryType::UNKNOWN(num), } } } @@ -113,40 +113,40 @@ Now we need a way of holding the data for these records, so we'll make some modifications to `DnsRecord`. ```rust -#[derive(Debug,Clone,PartialEq,Eq,Hash,PartialOrd,Ord)] +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] #[allow(dead_code)] pub enum DnsRecord { UNKNOWN { domain: String, qtype: u16, data_len: u16, - ttl: u32 + ttl: u32, }, // 0 A { domain: String, addr: Ipv4Addr, - ttl: u32 + ttl: u32, }, // 1 NS { domain: String, host: String, - ttl: u32 + ttl: u32, }, // 2 CNAME { domain: String, host: String, - ttl: u32 + ttl: u32, }, // 5 MX { domain: String, priority: u16, host: String, - ttl: u32 + ttl: u32, }, // 15 AAAA { domain: String, addr: Ipv6Addr, - ttl: u32 + ttl: u32, }, // 28 } ``` @@ -156,106 +156,101 @@ and reading records. Starting with read, we amend it with additional code for each record type. First off, we've got the common preamble: ```rust -pub fn read(buffer: &mut BytePacketBuffer) -> Result { - let mut domain = String::new(); - try!(buffer.read_qname(&mut domain)); +impl DnsRecord { + pub fn read(buffer: &mut BytePacketBuffer) -> Result { + let mut domain = String::new(); + buffer.read_qname(&mut domain)?; - let qtype_num = try!(buffer.read_u16()); - let qtype = QueryType::from_num(qtype_num); - let _ = try!(buffer.read_u16()); - let ttl = try!(buffer.read_u32()); - let data_len = try!(buffer.read_u16()); + let qtype_num = buffer.read_u16()?; + let qtype = QueryType::from_num(qtype_num); + let _ = buffer.read_u16()?; + let ttl = buffer.read_u32()?; + let data_len = buffer.read_u16()?; - match qtype { + match qtype { + QueryType::A => { + let raw_addr = buffer.read_u32()?; + let addr = Ipv4Addr::new( + ((raw_addr >> 24) & 0xFF) as u8, + ((raw_addr >> 16) & 0xFF) as u8, + ((raw_addr >> 8) & 0xFF) as u8, + ((raw_addr >> 0) & 0xFF) as u8, + ); - // Handle each record type separately, starting with the A record - // type which remains the same as before. - QueryType::A => { - let raw_addr = try!(buffer.read_u32()); - let addr = Ipv4Addr::new(((raw_addr >> 24) & 0xFF) as u8, - ((raw_addr >> 16) & 0xFF) as u8, - ((raw_addr >> 8) & 0xFF) as u8, - ((raw_addr >> 0) & 0xFF) as u8); + Ok(DnsRecord::A { + domain: domain, + addr: addr, + ttl: ttl, + }) + } + QueryType::AAAA => { + let raw_addr1 = buffer.read_u32()?; + let raw_addr2 = buffer.read_u32()?; + let raw_addr3 = buffer.read_u32()?; + let raw_addr4 = buffer.read_u32()?; + let addr = Ipv6Addr::new( + ((raw_addr1 >> 16) & 0xFFFF) as u16, + ((raw_addr1 >> 0) & 0xFFFF) as u16, + ((raw_addr2 >> 16) & 0xFFFF) as u16, + ((raw_addr2 >> 0) & 0xFFFF) as u16, + ((raw_addr3 >> 16) & 0xFFFF) as u16, + ((raw_addr3 >> 0) & 0xFFFF) as u16, + ((raw_addr4 >> 16) & 0xFFFF) as u16, + ((raw_addr4 >> 0) & 0xFFFF) as u16, + ); - Ok(DnsRecord::A { - domain: domain, - addr: addr, - ttl: ttl - }) - }, + Ok(DnsRecord::AAAA { + domain: domain, + addr: addr, + ttl: ttl, + }) + } + QueryType::NS => { + let mut ns = String::new(); + buffer.read_qname(&mut ns)?; - // The AAAA record type follows the same logic, but with more numbers to keep - // track off. - QueryType::AAAA => { - let raw_addr1 = try!(buffer.read_u32()); - let raw_addr2 = try!(buffer.read_u32()); - let raw_addr3 = try!(buffer.read_u32()); - let raw_addr4 = try!(buffer.read_u32()); - let addr = Ipv6Addr::new(((raw_addr1 >> 16) & 0xFFFF) as u16, - ((raw_addr1 >> 0) & 0xFFFF) as u16, - ((raw_addr2 >> 16) & 0xFFFF) as u16, - ((raw_addr2 >> 0) & 0xFFFF) as u16, - ((raw_addr3 >> 16) & 0xFFFF) as u16, - ((raw_addr3 >> 0) & 0xFFFF) as u16, - ((raw_addr4 >> 16) & 0xFFFF) as u16, - ((raw_addr4 >> 0) & 0xFFFF) as u16); + Ok(DnsRecord::NS { + domain: domain, + host: ns, + ttl: ttl, + }) + } + QueryType::CNAME => { + let mut cname = String::new(); + buffer.read_qname(&mut cname)?; - Ok(DnsRecord::AAAA { - domain: domain, - addr: addr, - ttl: ttl - }) - }, + Ok(DnsRecord::CNAME { + domain: domain, + host: cname, + ttl: ttl, + }) + } + QueryType::MX => { + let priority = buffer.read_u16()?; + let mut mx = String::new(); + buffer.read_qname(&mut mx)?; - // NS and CNAME both have the same structure. - QueryType::NS => { - let mut ns = String::new(); - try!(buffer.read_qname(&mut ns)); + Ok(DnsRecord::MX { + domain: domain, + priority: priority, + host: mx, + ttl: ttl, + }) + } + QueryType::UNKNOWN(_) => { + buffer.step(data_len as usize)?; - Ok(DnsRecord::NS { - domain: domain, - host: ns, - ttl: ttl - }) - }, - - QueryType::CNAME => { - let mut cname = String::new(); - try!(buffer.read_qname(&mut cname)); - - Ok(DnsRecord::CNAME { - domain: domain, - host: cname, - ttl: ttl - }) - }, - - // MX is almost like the previous two, but with one extra field for priority. - QueryType::MX => { - let priority = try!(buffer.read_u16()); - let mut mx = String::new(); - try!(buffer.read_qname(&mut mx)); - - Ok(DnsRecord::MX { - domain: domain, - priority: priority, - host: mx, - ttl: ttl - }) - }, - - // And we end with some code for handling unknown record types, as before. - QueryType::UNKNOWN(_) => { - try!(buffer.step(data_len as usize)); - - Ok(DnsRecord::UNKNOWN { - domain: domain, - qtype: qtype_num, - data_len: data_len, - ttl: ttl - }) + Ok(DnsRecord::UNKNOWN { + domain: domain, + qtype: qtype_num, + data_len: data_len, + ttl: ttl, + }) + } } } + + - snip - } ``` @@ -280,8 +275,8 @@ impl BytePacketBuffer { } fn set_u16(&mut self, pos: usize, val: u16) -> Result<()> { - try!(self.set(pos,(val >> 8) as u8)); - try!(self.set(pos+1,(val & 0xFF) as u8)); + self.set(pos, (val >> 8) as u8)?; + self.set(pos + 1, (val & 0xFF) as u8)?; Ok(()) } @@ -289,89 +284,119 @@ impl BytePacketBuffer { } ``` +When writing the labels of a record, we don't know ahead of time the number of +bytes needed, since we might end up using jumps to compress the size. We'll +solve this by writing a zero size and then going back to fill in the size +needed. + ### Extending DnsRecord for writing new record types Now we can amend `DnsRecord::write`. Here's our new function: ```rust -pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result { +impl DnsRecord { - let start_pos = buffer.pos(); + - snip - - match *self { - DnsRecord::A { ref domain, ref addr, ttl } => { - try!(buffer.write_qname(domain)); - try!(buffer.write_u16(QueryType::A.to_num())); - try!(buffer.write_u16(1)); - try!(buffer.write_u32(ttl)); - try!(buffer.write_u16(4)); + pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result { + let start_pos = buffer.pos(); - let octets = addr.octets(); - try!(buffer.write_u8(octets[0])); - try!(buffer.write_u8(octets[1])); - try!(buffer.write_u8(octets[2])); - try!(buffer.write_u8(octets[3])); - }, - DnsRecord::NS { ref domain, ref host, ttl } => { - try!(buffer.write_qname(domain)); - try!(buffer.write_u16(QueryType::NS.to_num())); - try!(buffer.write_u16(1)); - try!(buffer.write_u32(ttl)); + match *self { + DnsRecord::A { + ref domain, + ref addr, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::A.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + buffer.write_u16(4)?; - let pos = buffer.pos(); - try!(buffer.write_u16(0)); - - try!(buffer.write_qname(host)); - - let size = buffer.pos() - (pos + 2); - try!(buffer.set_u16(pos, size as u16)); - }, - DnsRecord::CNAME { ref domain, ref host, ttl } => { - try!(buffer.write_qname(domain)); - try!(buffer.write_u16(QueryType::CNAME.to_num())); - try!(buffer.write_u16(1)); - try!(buffer.write_u32(ttl)); - - let pos = buffer.pos(); - try!(buffer.write_u16(0)); - - try!(buffer.write_qname(host)); - - let size = buffer.pos() - (pos + 2); - try!(buffer.set_u16(pos, size as u16)); - }, - DnsRecord::MX { ref domain, priority, ref host, ttl } => { - try!(buffer.write_qname(domain)); - try!(buffer.write_u16(QueryType::MX.to_num())); - try!(buffer.write_u16(1)); - try!(buffer.write_u32(ttl)); - - let pos = buffer.pos(); - try!(buffer.write_u16(0)); - - try!(buffer.write_u16(priority)); - try!(buffer.write_qname(host)); - - let size = buffer.pos() - (pos + 2); - try!(buffer.set_u16(pos, size as u16)); - }, - DnsRecord::AAAA { ref domain, ref addr, ttl } => { - try!(buffer.write_qname(domain)); - try!(buffer.write_u16(QueryType::AAAA.to_num())); - try!(buffer.write_u16(1)); - try!(buffer.write_u32(ttl)); - try!(buffer.write_u16(16)); - - for octet in &addr.segments() { - try!(buffer.write_u16(*octet)); + let octets = addr.octets(); + buffer.write_u8(octets[0])?; + buffer.write_u8(octets[1])?; + buffer.write_u8(octets[2])?; + buffer.write_u8(octets[3])?; } - }, - DnsRecord::UNKNOWN { .. } => { - println!("Skipping record: {:?}", self); - } - } + DnsRecord::NS { + ref domain, + ref host, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::NS.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; - Ok(buffer.pos() - start_pos) + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::CNAME { + ref domain, + ref host, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::CNAME.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::MX { + ref domain, + priority, + ref host, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::MX.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_u16(priority)?; + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::AAAA { + ref domain, + ref addr, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::AAAA.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + buffer.write_u16(16)?; + + for octet in &addr.segments() { + buffer.write_u16(*octet)?; + } + } + DnsRecord::UNKNOWN { .. } => { + println!("Skipping record: {:?}", self); + } + } + + Ok(buffer.pos() - start_pos) + } } ``` diff --git a/chapter4.md b/chapter4.md index 5f80837..388b266 100644 --- a/chapter4.md +++ b/chapter4.md @@ -141,26 +141,27 @@ work, it's a rather quick effort! We'll start out by doing some quick refactoring, moving our lookup code into a separate function. This is for the most part the same code as we had in our -`main` function in the previous chapter, with the only change being that we -handle errors gracefully using `try!`. +`main` function in the previous chapter. ```rust fn lookup(qname: &str, qtype: QueryType, server: (&str, u16)) -> Result { - let socket = try!(UdpSocket::bind(("0.0.0.0", 43210))); + let socket = UdpSocket::bind(("0.0.0.0", 43210))?; let mut packet = DnsPacket::new(); packet.header.id = 6666; packet.header.questions = 1; packet.header.recursion_desired = true; - packet.questions.push(DnsQuestion::new(qname.to_string(), qtype)); + packet + .questions + .push(DnsQuestion::new(qname.to_string(), qtype)); let mut req_buffer = BytePacketBuffer::new(); - packet.write(&mut req_buffer).unwrap(); - try!(socket.send_to(&req_buffer.buf[0..req_buffer.pos], server)); + packet.write(&mut req_buffer)?; + socket.send_to(&req_buffer.buf[0..req_buffer.pos], server)?; let mut res_buffer = BytePacketBuffer::new(); - socket.recv_from(&mut res_buffer.buf).unwrap(); + socket.recv_from(&mut res_buffer.buf)?; DnsPacket::from_buffer(&mut res_buffer) } @@ -171,12 +172,12 @@ fn lookup(qname: &str, qtype: QueryType, server: (&str, u16)) -> Result Result<()> { // Forward queries to Google's public DNS let server = ("8.8.8.8", 53); // Bind an UDP socket on port 2053 - let socket = UdpSocket::bind(("0.0.0.0", 2053)).unwrap(); + let socket = UdpSocket::bind(("0.0.0.0", 2053))?; // For now, queries are handled sequentially, so an infinite loop for servicing // requests is initiated. @@ -224,7 +225,6 @@ fn main() { if request.questions.is_empty() { packet.header.rescode = ResultCode::FORMERR; } - // Usually a question will be present, though. else { let question = &request.questions[0]; @@ -254,37 +254,36 @@ fn main() { } else { packet.header.rescode = ResultCode::SERVFAIL; } - - // The only thing remaining is to encode our response and send it off! - - let mut res_buffer = BytePacketBuffer::new(); - match packet.write(&mut res_buffer) { - Ok(_) => {}, - Err(e) => { - println!("Failed to encode UDP response packet: {:?}", e); - continue; - } - }; - - let len = res_buffer.pos(); - let data = match res_buffer.get_range(0, len) { - Ok(x) => x, - Err(e) => { - println!("Failed to retrieve response buffer: {:?}", e); - continue; - } - }; - - match socket.send_to(data, src) { - Ok(_) => {}, - Err(e) => { - println!("Failed to send response buffer: {:?}", e); - continue; - } - }; } - } // End of request loop -} // End of main + + // The only thing remaining is to encode our response and send it off! + let mut res_buffer = BytePacketBuffer::new(); + match packet.write(&mut res_buffer) { + Ok(_) => {} + Err(e) => { + println!("Failed to encode UDP response packet: {:?}", e); + continue; + } + }; + + let len = res_buffer.pos(); + let data = match res_buffer.get_range(0, len) { + Ok(x) => x, + Err(e) => { + println!("Failed to retrieve response buffer: {:?}", e); + continue; + } + }; + + match socket.send_to(data, src) { + Ok(_) => {} + Err(e) => { + println!("Failed to send response buffer: {:?}", e); + continue; + } + }; + } +} ``` The match idiom for error handling is used again and again here, since we want to avoid diff --git a/chapter5.md b/chapter5.md index 6107f69..7023cf4 100644 --- a/chapter5.md +++ b/chapter5.md @@ -168,89 +168,68 @@ impl DnsPacket { - snip - - // It's useful to be able to pick a random A record from a packet. When we - // get multiple IP's for a single name, it doesn't matter which one we - // choose, so in those cases we can now pick one at random. + /// It's useful to be able to pick a random A record from a packet. When we + /// get multiple IP's for a single name, it doesn't matter which one we + /// choose, so in those cases we can now pick one at random. pub fn get_random_a(&self) -> Option { - if !self.answers.is_empty() { - let idx = random::() % self.answers.len(); - let a_record = &self.answers[idx]; - if let DnsRecord::A{ ref addr, .. } = *a_record { - return Some(addr.to_string()); - } - } - - None + self.answers + .iter() + .filter_map(|record| match record { + DnsRecord::A { ref addr, .. } => Some(addr.to_string()), + _ => None, + }) + .next() } - // We'll use the fact that name servers often bundle the corresponding - // A records when replying to an NS query to implement a function that returns - // the actual IP for an NS record if possible. + /// A helper function which returns an iterator over all name servers in + /// the authorities section, represented as (domain, host) tuples + fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator { + self.authorities.iter() + // In practice, these are always NS records in well formed packages. + // Convert the NS records to a tuple which has only the data we need + // to make it easy to work with. + .filter_map(|record| match record { + DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())), + _ => None, + }) + // Discard servers which aren't authoritative to our query + .filter(move |(domain, _)| qname.ends_with(*domain)) + } + + /// When there is a NS record in the authorities section, there may also + /// be a matching A record in the additional section. This saves us + /// from doing a separate query to resolve the IP of the name server. pub fn get_resolved_ns(&self, qname: &str) -> Option { + // Get an iterator over the nameservers in the authorities section + self.get_ns(qname) + // Now we need to look for a matching A record in the additional + // section. Since we just want the first valid record, we can just + // build a stream of matching records. + .flat_map(|(_, host)| { + self.resources.iter() + // Filter for A records where the domain match the host + // of the NS record that we are currently processing + .filter_map(move |record| match record { + DnsRecord::A { domain, addr, .. } if domain == host => Some(addr), + _ => None, + }) + }) + .map(|addr| addr.to_string()) + // Finally, pick the first valid entry + .next() + } - // First, we scan the list of NS records in the authorities section: - let mut new_authorities = Vec::new(); - for auth in &self.authorities { - if let DnsRecord::NS { ref domain, ref host, .. } = *auth { - if !qname.ends_with(domain) { - continue; - } - - // Once we've found an NS record, we scan the resources record for a matching - // A record... - for rsrc in &self.resources { - if let DnsRecord::A{ ref domain, ref addr, ttl } = *rsrc { - if domain != host { - continue; - } - - let rec = DnsRecord::A { - domain: host.clone(), - addr: *addr, - ttl: ttl - }; - - // ...and push any matches to a list. - new_authorities.push(rec); - } - } - } - } - - // If there are any matches, we pick the first one. - if !new_authorities.is_empty() { - if let DnsRecord::A { addr, .. } = new_authorities[0] { - return Some(addr.to_string()); - } - } - - None - } // End of get_resolved_ns - - // However, not all name servers are as that nice. In certain cases there won't - // be any A records in the additional section, and we'll have to perform *another* - // lookup in the midst. For this, we introduce a method for returning the host - // name of an appropriate name server. + /// However, not all name servers are as that nice. In certain cases there won't + /// be any A records in the additional section, and we'll have to perform *another* + /// lookup in the midst of our first. For this, we introduce a method for + returning the hostname of an appropriate name server. pub fn get_unresolved_ns(&self, qname: &str) -> Option { - - let mut new_authorities = Vec::new(); - for auth in &self.authorities { - if let DnsRecord::NS { ref domain, ref host, .. } = *auth { - if !qname.ends_with(domain) { - continue; - } - - new_authorities.push(host); - } - } - - if !new_authorities.is_empty() { - let idx = random::() % new_authorities.len(); - return Some(new_authorities[idx].clone()); - } - - None - } // End of get_unresolved_ns + // Get an iterator over the nameservers in the authorities section + self.get_ns(qname) + .map(|(_, host)| host.to_string()) + // Finally, pick the first valid entry + .next() + } } // End of DnsPacket ``` @@ -273,12 +252,10 @@ fn recursive_lookup(qname: &str, qtype: QueryType) -> Result { let ns_copy = ns.clone(); let server = (ns_copy.as_str(), 53); - let response = try!(lookup(qname, qtype.clone(), server)); + let response = lookup(qname, qtype.clone(), server)?; // If there are entries in the answer section, and no errors, we are done! - if !response.answers.is_empty() && - response.header.rescode == ResultCode::NOERROR { - + if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { return Ok(response.clone()); } @@ -301,23 +278,23 @@ fn recursive_lookup(qname: &str, qtype: QueryType) -> Result { // we'll go with what the last server told us. let new_ns_name = match response.get_unresolved_ns(qname) { Some(x) => x, - None => return Ok(response.clone()) + None => return Ok(response.clone()), }; // Here we go down the rabbit hole by starting _another_ lookup sequence in the // midst of our current one. Hopefully, this will give us the IP of an appropriate // name server. - let recursive_response = try!(recursive_lookup(&new_ns_name, QueryType::A)); + let recursive_response = recursive_lookup(&new_ns_name, QueryType::A)?; // Finally, we pick a random ip from the result, and restart the loop. If no such // record is available, we again return the last result we got. if let Some(new_ns) = recursive_response.get_random_a() { ns = new_ns.clone(); } else { - return Ok(response.clone()) + return Ok(response.clone()); } } -} // End of recursive_lookup +} ``` ### Trying out recursive lookup @@ -326,7 +303,7 @@ The only thing remaining is to change our main function to use `recursive_lookup`: ```rust -fn main() { +fn main() -> Result<()> { - snip - diff --git a/examples/sample1.rs b/examples/sample1.rs index c677d2a..c367fcb 100644 --- a/examples/sample1.rs +++ b/examples/sample1.rs @@ -11,6 +11,8 @@ pub struct BytePacketBuffer { } impl BytePacketBuffer { + /// This gives us a fresh buffer for holding the packet contents, and a + /// field for keeping track of where we are. pub fn new() -> BytePacketBuffer { BytePacketBuffer { buf: [0; 512], @@ -18,22 +20,26 @@ impl BytePacketBuffer { } } + /// Current position within buffer fn pos(&self) -> usize { self.pos } + /// Step the buffer position forward a specific number of steps fn step(&mut self, steps: usize) -> Result<()> { self.pos += steps; Ok(()) } + /// Change the buffer position fn seek(&mut self, pos: usize) -> Result<()> { self.pos = pos; Ok(()) } + /// Read a single byte and move the position one step forward fn read(&mut self) -> Result { if self.pos >= 512 { return Err("End of buffer".into()); @@ -44,6 +50,7 @@ impl BytePacketBuffer { Ok(res) } + /// Get a single byte, without changing the buffer position fn get(&mut self, pos: usize) -> Result { if pos >= 512 { return Err("End of buffer".into()); @@ -51,6 +58,7 @@ impl BytePacketBuffer { Ok(self.buf[pos]) } + /// Get a range of bytes fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> { if start + len >= 512 { return Err("End of buffer".into()); @@ -58,12 +66,14 @@ impl BytePacketBuffer { Ok(&self.buf[start..start + len as usize]) } + /// Read two bytes, stepping two steps forward fn read_u16(&mut self) -> Result { let res = ((self.read()? as u16) << 8) | (self.read()? as u16); Ok(res) } + /// Read four bytes, stepping four steps forward fn read_u32(&mut self) -> Result { let res = ((self.read()? as u32) << 24) | ((self.read()? as u32) << 16) @@ -73,13 +83,28 @@ impl BytePacketBuffer { Ok(res) } + /// Read a qname + /// + /// The tricky part: Reading domain names, taking labels into consideration. + /// Will take something like [3]www[6]google[3]com[0] and append + /// www.google.com to outstr. fn read_qname(&mut self, outstr: &mut String) -> Result<()> { + // Since we might encounter jumps, we'll keep track of our position + // locally as opposed to using the position within the struct. This + // allows us to move the shared position to a point past our current + // qname, while keeping track of our progress on the current qname + // using this variable. let mut pos = self.pos(); - let mut jumped = false; - let mut delim = ""; + // track whether or not we've jumped + let mut jumped = false; let max_jumps = 5; let mut jumps_performed = 0; + + // Our delimiter which we append for each label. Since we don't want a + // dot at the beginning of the domain name we'll leave it empty for now + // and set it to "." at the end of the first iteration. + let mut delim = ""; loop { // Dns Packets are untrusted data, so we need to be paranoid. Someone // can craft a packet with a cycle in the jump instructions. This guards @@ -88,42 +113,56 @@ impl BytePacketBuffer { return Err(format!("Limit of {} jumps exceeded", max_jumps).into()); } + // At this point, we're always at the beginning of a label. Recall + // that labels start with a length byte. let len = self.get(pos)?; - // A two byte sequence, where the two highest bits of the first byte is - // set, represents a offset relative to the start of the buffer. We - // handle this by jumping to the offset, setting a flag to indicate - // that we shouldn't update the shared buffer position once done. + // If len has the two most significant bit are set, it represents a + // jump to some other offset in the packet: if (len & 0xC0) == 0xC0 { - // When a jump is performed, we only modify the shared buffer - // position once, and avoid making the change later on. + // Update the buffer position to a point past the current + // label. We don't need to touch it any further. if !jumped { self.seek(pos + 2)?; } + // Read another byte, calculate offset and perform the jump by + // updating our local position variable let b2 = self.get(pos + 1)? as u16; let offset = (((len as u16) ^ 0xC0) << 8) | b2; pos = offset as usize; + + // Indicate that a jump was performed. jumped = true; jumps_performed += 1; + continue; } + // The base scenario, where we're reading a single label and + // appending it to the output: + else { + // Move a single byte forward to move past the length byte. + pos += 1; - pos += 1; + // Domain names are terminated by an empty label of length 0, + // so if the length is zero we're done. + if len == 0 { + break; + } - // Names are terminated by an empty label of length 0 - if len == 0 { - break; + // Append the delimiter to our output buffer first. + outstr.push_str(delim); + + // Extract the actual ASCII bytes for this label and append them + // to the output buffer. + let str_buffer = self.get_range(pos, len as usize)?; + outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase()); + + delim = "."; + + // Move forward the full length of the label. + pos += len as usize; } - - outstr.push_str(delim); - - let str_buffer = self.get_range(pos, len as usize)?; - outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase()); - - delim = "."; - - pos += len as usize; } if !jumped { @@ -386,19 +425,19 @@ fn main() -> Result<()> { f.read(&mut buffer.buf)?; let packet = DnsPacket::from_buffer(&mut buffer)?; - println!("{:?}", packet.header); + println!("{:#?}", packet.header); for q in packet.questions { - println!("{:?}", q); + println!("{:#?}", q); } for rec in packet.answers { - println!("{:?}", rec); + println!("{:#?}", rec); } for rec in packet.authorities { - println!("{:?}", rec); + println!("{:#?}", rec); } for rec in packet.resources { - println!("{:?}", rec); + println!("{:#?}", rec); } Ok(()) diff --git a/examples/sample2.rs b/examples/sample2.rs index 9b31908..cf0dba4 100644 --- a/examples/sample2.rs +++ b/examples/sample2.rs @@ -164,9 +164,7 @@ impl BytePacketBuffer { } fn write_qname(&mut self, qname: &str) -> Result<()> { - let split_str = qname.split('.').collect::>(); - - for label in split_str { + for label in qname.split('.') { let len = label.len(); if len > 0x34 { return Err("Single label exceeds 63 characters of length".into()); @@ -521,12 +519,18 @@ impl DnsPacket { } fn main() -> Result<()> { - let qname = "www.yahoo.com"; + // Perform an A query for google.com + let qname = "google.com"; let qtype = QueryType::A; + + // Using googles public DNS server let server = ("8.8.8.8", 53); + // Bind a UDP socket to an arbitrary port let socket = UdpSocket::bind(("0.0.0.0", 43210))?; + // Build our query packet. It's important that we remember to set the + // `recursion_desired` flag. As noted earlier, the packet id is arbitrary. let mut packet = DnsPacket::new(); packet.header.id = 6666; @@ -536,27 +540,34 @@ fn main() -> Result<()> { .questions .push(DnsQuestion::new(qname.to_string(), qtype)); + // Use our new write method to write the packet to a buffer... let mut req_buffer = BytePacketBuffer::new(); packet.write(&mut req_buffer)?; + + // ...and send it off to the server using our socket: socket.send_to(&req_buffer.buf[0..req_buffer.pos], server)?; + // To prepare for receiving the response, we'll create a new `BytePacketBuffer`, + // and ask the socket to write the response directly into our buffer. let mut res_buffer = BytePacketBuffer::new(); socket.recv_from(&mut res_buffer.buf)?; + // As per the previous section, `DnsPacket::from_buffer()` is then used to + // actually parse the packet after which we can print the response. let res_packet = DnsPacket::from_buffer(&mut res_buffer)?; - println!("{:?}", res_packet.header); + println!("{:#?}", res_packet.header); for q in res_packet.questions { - println!("{:?}", q); + println!("{:#?}", q); } for rec in res_packet.answers { - println!("{:?}", rec); + println!("{:#?}", rec); } for rec in res_packet.authorities { - println!("{:?}", rec); + println!("{:#?}", rec); } for rec in res_packet.resources { - println!("{:?}", rec); + println!("{:#?}", rec); } Ok(()) diff --git a/examples/sample3.rs b/examples/sample3.rs index 2052475..e995562 100644 --- a/examples/sample3.rs +++ b/examples/sample3.rs @@ -164,9 +164,7 @@ impl BytePacketBuffer { } fn write_qname(&mut self, qname: &str) -> Result<()> { - let split_str = qname.split('.').collect::>(); - - for label in split_str { + for label in qname.split('.') { let len = label.len(); if len > 0x34 { return Err("Single label exceeds 63 characters of length".into()); diff --git a/examples/sample4.rs b/examples/sample4.rs index d85ba9c..0f928cd 100644 --- a/examples/sample4.rs +++ b/examples/sample4.rs @@ -164,9 +164,7 @@ impl BytePacketBuffer { } fn write_qname(&mut self, qname: &str) -> Result<()> { - let split_str = qname.split('.').collect::>(); - - for label in split_str { + for label in qname.split('.') { let len = label.len(); if len > 0x34 { return Err("Single label exceeds 63 characters of length".into()); @@ -714,11 +712,17 @@ fn lookup(qname: &str, qtype: QueryType, server: (&str, u16)) -> Result Result<()> { + // Forward queries to Google's public DNS let server = ("8.8.8.8", 53); + // Bind an UDP socket on port 2053 let socket = UdpSocket::bind(("0.0.0.0", 2053))?; + // For now, queries are handled sequentially, so an infinite loop for servicing + // requests is initiated. loop { + // With a socket ready, we can go ahead and read a packet. This will + // block until one is received. let mut req_buffer = BytePacketBuffer::new(); let (_, src) = match socket.recv_from(&mut req_buffer.buf) { Ok(x) => x, @@ -728,6 +732,16 @@ fn main() -> Result<()> { } }; + // Here we use match to safely unwrap the `Result`. If everything's as expected, + // the raw bytes are simply returned, and if not it'll abort by restarting the + // loop and waiting for the next request. The `recv_from` function will write the + // data into the provided buffer, and return the length of the data read as well + // as the source address. We're not interested in the length, but we need to keep + // track of the source in order to send our reply later on. + + // Next, `DnsPacket::from_buffer` is used to parse the raw bytes into + // a `DnsPacket`. It uses the same error handling idiom as the previous statement. + let request = match DnsPacket::from_buffer(&mut req_buffer) { Ok(x) => x, Err(e) => { @@ -736,18 +750,29 @@ fn main() -> Result<()> { } }; + // Create and initialize the response packet let mut packet = DnsPacket::new(); packet.header.id = request.header.id; packet.header.recursion_desired = true; packet.header.recursion_available = true; packet.header.response = true; + // Being mindful of how unreliable input data from arbitrary senders can be, we + // need make sure that a question is actually present. If not, we return `FORMERR` + // to indicate that the sender made something wrong. if request.questions.is_empty() { packet.header.rescode = ResultCode::FORMERR; - } else { + } + // Usually a question will be present, though. + else { let question = &request.questions[0]; println!("Received query: {:?}", question); + // Since all is set up and as expected, the query can be forwarded to the target + // server. There's always the possibility that the query will fail, in which case + // the `SERVFAIL` response code is set to indicate as much to the client. If + // rather everything goes as planned, the question and response records as copied + // into our response packet. if let Ok(result) = lookup(&question.name, question.qtype, server) { packet.questions.push(question.clone()); packet.header.rescode = result.header.rescode; @@ -769,6 +794,7 @@ fn main() -> Result<()> { } } + // The only thing remaining is to encode our response and send it off! let mut res_buffer = BytePacketBuffer::new(); match packet.write(&mut res_buffer) { Ok(_) => {} diff --git a/examples/sample5.rs b/examples/sample5.rs index 191fdd6..354d4bd 100644 --- a/examples/sample5.rs +++ b/examples/sample5.rs @@ -164,9 +164,7 @@ impl BytePacketBuffer { } fn write_qname(&mut self, qname: &str) -> Result<()> { - let split_str = qname.split('.').collect::>(); - - for label in split_str { + for label in qname.split('.') { let len = label.len(); if len > 0x34 { return Err("Single label exceeds 63 characters of length".into()); @@ -690,84 +688,69 @@ impl DnsPacket { Ok(()) } + /// It's useful to be able to pick a random A record from a packet. When we + /// get multiple IP's for a single name, it doesn't matter which one we + /// choose, so in those cases we can now pick one at random. pub fn get_random_a(&self) -> Option { - if !self.answers.is_empty() { - let a_record = &self.answers[0]; - if let DnsRecord::A { ref addr, .. } = *a_record { - return Some(addr.to_string()); - } - } - - None + self.answers + .iter() + .filter_map(|record| match record { + DnsRecord::A { ref addr, .. } => Some(addr.to_string()), + _ => None, + }) + .next() } + /// A helper function which returns an iterator over all name servers in + /// the authorities section, represented as (domain, host) tuples + fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator { + self.authorities + .iter() + // In practice, these are always NS records in well formed packages. + // Convert the NS records to a tuple which has only the data we need + // to make it easy to work with. + .filter_map(|record| match record { + DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())), + _ => None, + }) + // Discard servers which aren't authoritative to our query + .filter(move |(domain, _)| qname.ends_with(*domain)) + } + + /// We'll use the fact that name servers often bundle the corresponding + /// A records when replying to an NS query to implement a function that + /// returns the actual IP for an NS record if possible. pub fn get_resolved_ns(&self, qname: &str) -> Option { - let mut new_authorities = Vec::new(); - for auth in &self.authorities { - if let DnsRecord::NS { - ref domain, - ref host, - .. - } = *auth - { - if !qname.ends_with(domain) { - continue; - } - - for rsrc in &self.resources { - if let DnsRecord::A { - ref domain, - ref addr, - ttl, - } = *rsrc - { - if domain != host { - continue; - } - - let rec = DnsRecord::A { - domain: host.clone(), - addr: *addr, - ttl: ttl, - }; - - new_authorities.push(rec); - } - } - } - } - - if !new_authorities.is_empty() { - if let DnsRecord::A { addr, .. } = new_authorities[0] { - return Some(addr.to_string()); - } - } - - None + // Get an iterator over the nameservers in the authorities section + self.get_ns(qname) + // Now we need to look for a matching A record in the additional + // section. Since we just want the first valid record, we can just + // build a stream of matching records. + .flat_map(|(_, host)| { + self.resources + .iter() + // Filter for A records where the domain match the host + // of the NS record that we are currently processing + .filter_map(move |record| match record { + DnsRecord::A { domain, addr, .. } if domain == host => Some(addr), + _ => None, + }) + }) + .map(|addr| addr.to_string()) + // Finally, pick the first valid entry + .next() } + /// However, not all name servers are as that nice. In certain cases there won't + /// be any A records in the additional section, and we'll have to perform *another* + /// lookup in the midst. For this, we introduce a method for returning the host + /// name of an appropriate name server. pub fn get_unresolved_ns(&self, qname: &str) -> Option { - let mut new_authorities = Vec::new(); - for auth in &self.authorities { - if let DnsRecord::NS { - ref domain, - ref host, - .. - } = *auth - { - if !qname.ends_with(domain) { - continue; - } - - new_authorities.push(host); - } - } - - if !new_authorities.is_empty() { - return Some(new_authorities[0].clone()); - } - - None + // Get an iterator over the nameservers in the authorities section + self.get_ns(qname) + .map(|(_, host)| host.to_string()) + // Finally, pick the first valid entry + .next() } } @@ -794,45 +777,53 @@ fn lookup(qname: &str, qtype: QueryType, server: (&str, u16)) -> Result Result { + // For now we're always starting with *a.root-servers.net*. let mut ns = "198.41.0.4".to_string(); - // Start querying name servers + // Since it might take an arbitrary number of steps, we enter an unbounded loop. loop { println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns); + // The next step is to send the query to the active server. let ns_copy = ns.clone(); let server = (ns_copy.as_str(), 53); let response = lookup(qname, qtype.clone(), server)?; - // If we've got an actual answer, we're done! + // If there are entries in the answer section, and no errors, we are done! if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { return Ok(response.clone()); } + // We might also get a `NXDOMAIN` reply, which is the authoritative name servers + // way of telling us that the name doesn't exist. if response.header.rescode == ResultCode::NXDOMAIN { return Ok(response.clone()); } - // Otherwise, try to find a new nameserver based on NS and a - // corresponding A record in the additional section + // Otherwise, we'll try to find a new nameserver based on NS and a corresponding A + // record in the additional section. If this succeeds, we can switch name server + // and retry the loop. if let Some(new_ns) = response.get_resolved_ns(qname) { - // If there is such a record, we can retry the loop with that NS ns = new_ns.clone(); continue; } - // If not, we'll have to resolve the ip of a NS record + // If not, we'll have to resolve the ip of a NS record. If no NS records exist, + // we'll go with what the last server told us. let new_ns_name = match response.get_unresolved_ns(qname) { Some(x) => x, None => return Ok(response.clone()), }; - // Recursively resolve the NS + // Here we go down the rabbit hole by starting _another_ lookup sequence in the + // midst of our current one. Hopefully, this will give us the IP of an appropriate + // name server. let recursive_response = recursive_lookup(&new_ns_name, QueryType::A)?; - // Pick a random IP and restart + // Finally, we pick a random ip from the result, and restart the loop. If no such + // record is available, we again return the last result we got. if let Some(new_ns) = recursive_response.get_random_a() { ns = new_ns.clone(); } else {