Update examples with modernized code

This commit is contained in:
Emil Hernvall 2020-06-18 01:47:09 +02:00
parent 31369696d9
commit f815075ae4
10 changed files with 708 additions and 627 deletions

View File

@ -323,43 +323,43 @@ we'll use a `struct` called `BytePacketBuffer`.
```rust ```rust
pub struct BytePacketBuffer { pub struct BytePacketBuffer {
pub buf: [u8; 512], pub buf: [u8; 512],
pub pos: usize pub pos: usize,
} }
impl BytePacketBuffer { impl BytePacketBuffer {
// This gives us a fresh buffer for holding the packet contents, and a field for /// This gives us a fresh buffer for holding the packet contents, and a
// keeping track of where we are. /// field for keeping track of where we are.
pub fn new() -> BytePacketBuffer { pub fn new() -> BytePacketBuffer {
BytePacketBuffer { BytePacketBuffer {
buf: [0; 512], buf: [0; 512],
pos: 0 pos: 0,
} }
} }
// When handling the reading of domain names, we'll need a way of /// Current position within buffer
// reading and manipulating our buffer position.
fn pos(&self) -> usize { fn pos(&self) -> usize {
self.pos self.pos
} }
/// Step the buffer position forward a specific number of steps
fn step(&mut self, steps: usize) -> Result<()> { fn step(&mut self, steps: usize) -> Result<()> {
self.pos += steps; self.pos += steps;
Ok(()) Ok(())
} }
/// Change the buffer position
fn seek(&mut self, pos: usize) -> Result<()> { fn seek(&mut self, pos: usize) -> Result<()> {
self.pos = pos; self.pos = pos;
Ok(()) Ok(())
} }
// A method for reading a single byte, and moving one step forward /// Read a single byte and move the position one step forward
fn read(&mut self) -> Result<u8> { fn read(&mut self) -> Result<u8> {
if self.pos >= 512 { if self.pos >= 512 {
return Err(Error::new(ErrorKind::InvalidInput, "End of buffer")); return Err("End of buffer".into());
} }
let res = self.buf[self.pos]; let res = self.buf[self.pos];
self.pos += 1; self.pos += 1;
@ -367,49 +367,46 @@ impl BytePacketBuffer {
Ok(res) Ok(res)
} }
// Methods for fetching data at a specified position, without modifying /// Get a single byte, without changing the buffer position
// the internal position
fn get(&mut self, pos: usize) -> Result<u8> { fn get(&mut self, pos: usize) -> Result<u8> {
if pos >= 512 { if pos >= 512 {
return Err(Error::new(ErrorKind::InvalidInput, "End of buffer")); return Err("End of buffer".into());
} }
Ok(self.buf[pos]) Ok(self.buf[pos])
} }
/// Get a range of bytes
fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> { fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
if start + len >= 512 { if start + len >= 512 {
return Err(Error::new(ErrorKind::InvalidInput, "End of buffer")); return Err("End of buffer".into());
} }
Ok(&self.buf[start..start+len as usize]) Ok(&self.buf[start..start + len as usize])
} }
// Methods for reading a u16 and u32 from the buffer, while stepping /// Read two bytes, stepping two steps forward
// forward 2 or 4 bytes fn read_u16(&mut self) -> Result<u16> {
let res = ((self.read()? as u16) << 8) | (self.read()? as u16);
fn read_u16(&mut self) -> Result<u16>
{
let res = ((try!(self.read()) as u16) << 8) |
(try!(self.read()) as u16);
Ok(res) Ok(res)
} }
fn read_u32(&mut self) -> Result<u32> /// Read four bytes, stepping four steps forward
{ fn read_u32(&mut self) -> Result<u32> {
let res = ((try!(self.read()) as u32) << 24) | let res = ((self.read()? as u32) << 24)
((try!(self.read()) as u32) << 16) | | ((self.read()? as u32) << 16)
((try!(self.read()) as u32) << 8) | | ((self.read()? as u32) << 8)
((try!(self.read()) as u32) << 0); | ((self.read()? as u32) << 0);
Ok(res) Ok(res)
} }
// The tricky part: Reading domain names, taking labels into consideration.
// Will take something like [3]www[6]google[3]com[0] and append /// Read a qname
// www.google.com to outstr. ///
fn read_qname(&mut self, outstr: &mut String) -> Result<()> /// The tricky part: Reading domain names, taking labels into consideration.
{ /// Will take something like [3]www[6]google[3]com[0] and append
/// www.google.com to outstr.
fn read_qname(&mut self, outstr: &mut String) -> Result<()> {
// Since we might encounter jumps, we'll keep track of our position // Since we might encounter jumps, we'll keep track of our position
// locally as opposed to using the position within the struct. This // locally as opposed to using the position within the struct. This
// allows us to move the shared position to a point past our current // allows us to move the shared position to a point past our current
@ -419,43 +416,54 @@ impl BytePacketBuffer {
// track whether or not we've jumped // track whether or not we've jumped
let mut jumped = false; let mut jumped = false;
let max_jumps = 5;
let mut jumps_performed = 0;
// Our delimiter which we append for each label. Since we don't want a dot at the // Our delimiter which we append for each label. Since we don't want a
// beginning of the domain name we'll leave it empty for now and set it to "." at // dot at the beginning of the domain name we'll leave it empty for now
// the end of the first iteration. // and set it to "." at the end of the first iteration.
let mut delim = ""; let mut delim = "";
loop { loop {
// Dns Packets are untrusted data, so we need to be paranoid. Someone
// can craft a packet with a cycle in the jump instructions. This guards
// against such packets.
if jumps_performed > max_jumps {
return Err(format!("Limit of {} jumps exceeded", max_jumps).into());
}
// At this point, we're always at the beginning of a label. Recall // At this point, we're always at the beginning of a label. Recall
// that labels start with a length byte. // that labels start with a length byte.
let len = try!(self.get(pos)); let len = self.get(pos)?;
// If len has the two most significant bit are set, it represents a jump to // If len has the two most significant bit are set, it represents a
// some other offset in the packet: // jump to some other offset in the packet:
if (len & 0xC0) == 0xC0 { if (len & 0xC0) == 0xC0 {
// Update the buffer position to a point past the current // Update the buffer position to a point past the current
// label. We don't need to touch it any further. // label. We don't need to touch it any further.
if !jumped { if !jumped {
try!(self.seek(pos+2)); self.seek(pos + 2)?;
} }
// Read another byte, calculate offset and perform the jump by // Read another byte, calculate offset and perform the jump by
// updating our local position variable // updating our local position variable
let b2 = try!(self.get(pos+1)) as u16; let b2 = self.get(pos + 1)? as u16;
let offset = (((len as u16) ^ 0xC0) << 8) | b2; let offset = (((len as u16) ^ 0xC0) << 8) | b2;
pos = offset as usize; pos = offset as usize;
// Indicate that a jump was performed. // Indicate that a jump was performed.
jumped = true; jumped = true;
} jumps_performed += 1;
continue;
}
// The base scenario, where we're reading a single label and // The base scenario, where we're reading a single label and
// appending it to the output: // appending it to the output:
else { else {
// Move a single byte forward to move past the length byte. // Move a single byte forward to move past the length byte.
pos += 1; pos += 1;
// Domain names are terminated by an empty label of length 0, so if the length is zero // Domain names are terminated by an empty label of length 0,
// we're done. // so if the length is zero we're done.
if len == 0 { if len == 0 {
break; break;
} }
@ -463,9 +471,9 @@ impl BytePacketBuffer {
// Append the delimiter to our output buffer first. // Append the delimiter to our output buffer first.
outstr.push_str(delim); outstr.push_str(delim);
// Extract the actual ASCII bytes for this label and append them to the output buffer. // Extract the actual ASCII bytes for this label and append them
// to the output buffer.
let str_buffer = try!(self.get_range(pos, len as usize)); let str_buffer = self.get_range(pos, len as usize)?;
outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase()); outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase());
delim = "."; delim = ".";
@ -475,16 +483,13 @@ impl BytePacketBuffer {
} }
} }
// If a jump has been performed, we've already modified the buffer position state and
// shouldn't do so again.
if !jumped { if !jumped {
try!(self.seek(pos)); self.seek(pos)?;
} }
Ok(()) Ok(())
} // End of read_qname }
}
} // End of BytePacketBuffer
``` ```
### ResultCode ### ResultCode
@ -492,14 +497,14 @@ impl BytePacketBuffer {
Before we move on to the header, we'll add an enum for the values of `rescode` field: Before we move on to the header, we'll add an enum for the values of `rescode` field:
```rust ```rust
#[derive(Copy,Clone,Debug,PartialEq,Eq)] #[derive(Copy, Clone, Debug, PartialEq, Eq)]
pub enum ResultCode { pub enum ResultCode {
NOERROR = 0, NOERROR = 0,
FORMERR = 1, FORMERR = 1,
SERVFAIL = 2, SERVFAIL = 2,
NXDOMAIN = 3, NXDOMAIN = 3,
NOTIMP = 4, NOTIMP = 4,
REFUSED = 5 REFUSED = 5,
} }
impl ResultCode { impl ResultCode {
@ -510,7 +515,7 @@ impl ResultCode {
3 => ResultCode::NXDOMAIN, 3 => ResultCode::NXDOMAIN,
4 => ResultCode::NOTIMP, 4 => ResultCode::NOTIMP,
5 => ResultCode::REFUSED, 5 => ResultCode::REFUSED,
0 | _ => ResultCode::NOERROR 0 | _ => ResultCode::NOERROR,
} }
} }
} }
@ -521,26 +526,26 @@ impl ResultCode {
Now we can get to work on the header. We'll represent it like this: Now we can get to work on the header. We'll represent it like this:
```rust ```rust
#[derive(Clone,Debug)] #[derive(Clone, Debug)]
pub struct DnsHeader { pub struct DnsHeader {
pub id: u16, // 16 bits pub id: u16, // 16 bits
pub recursion_desired: bool, // 1 bit pub recursion_desired: bool, // 1 bit
pub truncated_message: bool, // 1 bit pub truncated_message: bool, // 1 bit
pub authoritative_answer: bool, // 1 bit pub authoritative_answer: bool, // 1 bit
pub opcode: u8, // 4 bits pub opcode: u8, // 4 bits
pub response: bool, // 1 bit pub response: bool, // 1 bit
pub rescode: ResultCode, // 4 bits pub rescode: ResultCode, // 4 bits
pub checking_disabled: bool, // 1 bit pub checking_disabled: bool, // 1 bit
pub authed_data: bool, // 1 bit pub authed_data: bool, // 1 bit
pub z: bool, // 1 bit pub z: bool, // 1 bit
pub recursion_available: bool, // 1 bit pub recursion_available: bool, // 1 bit
pub questions: u16, // 16 bits pub questions: u16, // 16 bits
pub answers: u16, // 16 bits pub answers: u16, // 16 bits
pub authoritative_entries: u16, // 16 bits pub authoritative_entries: u16, // 16 bits
pub resource_entries: u16 // 16 bits pub resource_entries: u16, // 16 bits
} }
``` ```
@ -549,30 +554,32 @@ The implementation involves a lot of bit twiddling:
```rust ```rust
impl DnsHeader { impl DnsHeader {
pub fn new() -> DnsHeader { pub fn new() -> DnsHeader {
DnsHeader { id: 0, DnsHeader {
id: 0,
recursion_desired: false, recursion_desired: false,
truncated_message: false, truncated_message: false,
authoritative_answer: false, authoritative_answer: false,
opcode: 0, opcode: 0,
response: false, response: false,
rescode: ResultCode::NOERROR, rescode: ResultCode::NOERROR,
checking_disabled: false, checking_disabled: false,
authed_data: false, authed_data: false,
z: false, z: false,
recursion_available: false, recursion_available: false,
questions: 0, questions: 0,
answers: 0, answers: 0,
authoritative_entries: 0, authoritative_entries: 0,
resource_entries: 0 } resource_entries: 0,
}
} }
pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<()> { pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<()> {
self.id = try!(buffer.read_u16()); self.id = buffer.read_u16()?;
let flags = try!(buffer.read_u16()); let flags = buffer.read_u16()?;
let a = (flags >> 8) as u8; let a = (flags >> 8) as u8;
let b = (flags & 0xFF) as u8; let b = (flags & 0xFF) as u8;
self.recursion_desired = (a & (1 << 0)) > 0; self.recursion_desired = (a & (1 << 0)) > 0;
@ -587,10 +594,10 @@ impl DnsHeader {
self.z = (b & (1 << 6)) > 0; self.z = (b & (1 << 6)) > 0;
self.recursion_available = (b & (1 << 7)) > 0; self.recursion_available = (b & (1 << 7)) > 0;
self.questions = try!(buffer.read_u16()); self.questions = buffer.read_u16()?;
self.answers = try!(buffer.read_u16()); self.answers = buffer.read_u16()?;
self.authoritative_entries = try!(buffer.read_u16()); self.authoritative_entries = buffer.read_u16()?;
self.resource_entries = try!(buffer.read_u16()); self.resource_entries = buffer.read_u16()?;
// Return the constant header size // Return the constant header size
Ok(()) Ok(())
@ -604,7 +611,7 @@ Before moving on to the question part of the packet, we'll need a way to
represent the record type being queried: represent the record type being queried:
```rust ```rust
#[derive(PartialEq,Eq,Debug,Clone,Hash,Copy)] #[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)]
pub enum QueryType { pub enum QueryType {
UNKNOWN(u16), UNKNOWN(u16),
A, // 1 A, // 1
@ -621,7 +628,7 @@ impl QueryType {
pub fn from_num(num: u16) -> QueryType { pub fn from_num(num: u16) -> QueryType {
match num { match num {
1 => QueryType::A, 1 => QueryType::A,
_ => QueryType::UNKNOWN(num) _ => QueryType::UNKNOWN(num),
} }
} }
} }
@ -633,24 +640,24 @@ The enum allows us to easily add more record types later on. Now for the
question entries: question entries:
```rust ```rust
#[derive(Debug,Clone,PartialEq,Eq)] #[derive(Debug, Clone, PartialEq, Eq)]
pub struct DnsQuestion { pub struct DnsQuestion {
pub name: String, pub name: String,
pub qtype: QueryType pub qtype: QueryType,
} }
impl DnsQuestion { impl DnsQuestion {
pub fn new(name: String, qtype: QueryType) -> DnsQuestion { pub fn new(name: String, qtype: QueryType) -> DnsQuestion {
DnsQuestion { DnsQuestion {
name: name, name: name,
qtype: qtype qtype: qtype,
} }
} }
pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<()> { pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<()> {
try!(buffer.read_qname(&mut self.name)); buffer.read_qname(&mut self.name)?;
self.qtype = QueryType::from_num(try!(buffer.read_u16())); // qtype self.qtype = QueryType::from_num(buffer.read_u16()?); // qtype
let _ = try!(buffer.read_u16()); // class let _ = buffer.read_u16()?; // class
Ok(()) Ok(())
} }
@ -666,19 +673,19 @@ We'll obviously need a way of representing the actual dns records as well, and
again we'll use an enum for easy expansion: again we'll use an enum for easy expansion:
```rust ```rust
#[derive(Debug,Clone,PartialEq,Eq,Hash,PartialOrd,Ord)] #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(dead_code)] #[allow(dead_code)]
pub enum DnsRecord { pub enum DnsRecord {
UNKNOWN { UNKNOWN {
domain: String, domain: String,
qtype: u16, qtype: u16,
data_len: u16, data_len: u16,
ttl: u32 ttl: u32,
}, // 0 }, // 0
A { A {
domain: String, domain: String,
addr: Ipv4Addr, addr: Ipv4Addr,
ttl: u32 ttl: u32,
}, // 1 }, // 1
} }
``` ```
@ -690,39 +697,40 @@ this:
```rust ```rust
impl DnsRecord { impl DnsRecord {
pub fn read(buffer: &mut BytePacketBuffer) -> Result<DnsRecord> { pub fn read(buffer: &mut BytePacketBuffer) -> Result<DnsRecord> {
let mut domain = String::new(); let mut domain = String::new();
try!(buffer.read_qname(&mut domain)); buffer.read_qname(&mut domain)?;
let qtype_num = try!(buffer.read_u16()); let qtype_num = buffer.read_u16()?;
let qtype = QueryType::from_num(qtype_num); let qtype = QueryType::from_num(qtype_num);
let _ = try!(buffer.read_u16()); // class, which we ignore let _ = buffer.read_u16()?;
let ttl = try!(buffer.read_u32()); let ttl = buffer.read_u32()?;
let data_len = try!(buffer.read_u16()); let data_len = buffer.read_u16()?;
match qtype { match qtype {
QueryType::A => { QueryType::A => {
let raw_addr = try!(buffer.read_u32()); let raw_addr = buffer.read_u32()?;
let addr = Ipv4Addr::new(((raw_addr >> 24) & 0xFF) as u8, let addr = Ipv4Addr::new(
((raw_addr >> 16) & 0xFF) as u8, ((raw_addr >> 24) & 0xFF) as u8,
((raw_addr >> 8) & 0xFF) as u8, ((raw_addr >> 16) & 0xFF) as u8,
((raw_addr >> 0) & 0xFF) as u8); ((raw_addr >> 8) & 0xFF) as u8,
((raw_addr >> 0) & 0xFF) as u8,
);
Ok(DnsRecord::A { Ok(DnsRecord::A {
domain: domain, domain: domain,
addr: addr, addr: addr,
ttl: ttl ttl: ttl,
}) })
}, }
QueryType::UNKNOWN(_) => { QueryType::UNKNOWN(_) => {
try!(buffer.step(data_len as usize)); buffer.step(data_len as usize)?;
Ok(DnsRecord::UNKNOWN { Ok(DnsRecord::UNKNOWN {
domain: domain, domain: domain,
qtype: qtype_num, qtype: qtype_num,
data_len: data_len, data_len: data_len,
ttl: ttl ttl: ttl,
}) })
} }
} }
@ -741,7 +749,7 @@ pub struct DnsPacket {
pub questions: Vec<DnsQuestion>, pub questions: Vec<DnsQuestion>,
pub answers: Vec<DnsRecord>, pub answers: Vec<DnsRecord>,
pub authorities: Vec<DnsRecord>, pub authorities: Vec<DnsRecord>,
pub resources: Vec<DnsRecord> pub resources: Vec<DnsRecord>,
} }
impl DnsPacket { impl DnsPacket {
@ -751,31 +759,30 @@ impl DnsPacket {
questions: Vec::new(), questions: Vec::new(),
answers: Vec::new(), answers: Vec::new(),
authorities: Vec::new(), authorities: Vec::new(),
resources: Vec::new() resources: Vec::new(),
} }
} }
pub fn from_buffer(buffer: &mut BytePacketBuffer) -> Result<DnsPacket> { pub fn from_buffer(buffer: &mut BytePacketBuffer) -> Result<DnsPacket> {
let mut result = DnsPacket::new(); let mut result = DnsPacket::new();
try!(result.header.read(buffer)); result.header.read(buffer)?;
for _ in 0..result.header.questions { for _ in 0..result.header.questions {
let mut question = DnsQuestion::new("".to_string(), let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0));
QueryType::UNKNOWN(0)); question.read(buffer)?;
try!(question.read(buffer));
result.questions.push(question); result.questions.push(question);
} }
for _ in 0..result.header.answers { for _ in 0..result.header.answers {
let rec = try!(DnsRecord::read(buffer)); let rec = DnsRecord::read(buffer)?;
result.answers.push(rec); result.answers.push(rec);
} }
for _ in 0..result.header.authoritative_entries { for _ in 0..result.header.authoritative_entries {
let rec = try!(DnsRecord::read(buffer)); let rec = DnsRecord::read(buffer)?;
result.authorities.push(rec); result.authorities.push(rec);
} }
for _ in 0..result.header.resource_entries { for _ in 0..result.header.resource_entries {
let rec = try!(DnsRecord::read(buffer)); let rec = DnsRecord::read(buffer)?;
result.resources.push(rec); result.resources.push(rec);
} }
@ -789,26 +796,28 @@ impl DnsPacket {
Let's use the `response_packet.txt` we generated earlier to try it out! Let's use the `response_packet.txt` we generated earlier to try it out!
```rust ```rust
fn main() { fn main() -> Result<()> {
let mut f = File::open("response_packet.txt").unwrap(); let mut f = File::open("response_packet.txt")?;
let mut buffer = BytePacketBuffer::new(); let mut buffer = BytePacketBuffer::new();
f.read(&mut buffer.buf).unwrap(); f.read(&mut buffer.buf)?;
let packet = DnsPacket::from_buffer(&mut buffer).unwrap(); let packet = DnsPacket::from_buffer(&mut buffer)?;
println!("{:?}", packet.header); println!("{:#?}", packet.header);
for q in packet.questions { for q in packet.questions {
println!("{:?}", q); println!("{:#?}", q);
} }
for rec in packet.answers { for rec in packet.answers {
println!("{:?}", rec); println!("{:#?}", rec);
} }
for rec in packet.authorities { for rec in packet.authorities {
println!("{:?}", rec); println!("{:#?}", rec);
} }
for rec in packet.resources { for rec in packet.resources {
println!("{:?}", rec); println!("{:#?}", rec);
} }
Ok(())
} }
``` ```

View File

@ -21,7 +21,7 @@ impl BytePacketBuffer {
fn write(&mut self, val: u8) -> Result<()> { fn write(&mut self, val: u8) -> Result<()> {
if self.pos >= 512 { if self.pos >= 512 {
return Err(Error::new(ErrorKind::InvalidInput, "End of buffer")); return Err("End of buffer".into());
} }
self.buf[self.pos] = val; self.buf[self.pos] = val;
self.pos += 1; self.pos += 1;
@ -29,23 +29,23 @@ impl BytePacketBuffer {
} }
fn write_u8(&mut self, val: u8) -> Result<()> { fn write_u8(&mut self, val: u8) -> Result<()> {
try!(self.write(val)); self.write(val)?;
Ok(()) Ok(())
} }
fn write_u16(&mut self, val: u16) -> Result<()> { fn write_u16(&mut self, val: u16) -> Result<()> {
try!(self.write((val >> 8) as u8)); self.write((val >> 8) as u8)?;
try!(self.write((val & 0xFF) as u8)); self.write((val & 0xFF) as u8)?;
Ok(()) Ok(())
} }
fn write_u32(&mut self, val: u32) -> Result<()> { fn write_u32(&mut self, val: u32) -> Result<()> {
try!(self.write(((val >> 24) & 0xFF) as u8)); self.write(((val >> 24) & 0xFF) as u8)?;
try!(self.write(((val >> 16) & 0xFF) as u8)); self.write(((val >> 16) & 0xFF) as u8)?;
try!(self.write(((val >> 8) & 0xFF) as u8)); self.write(((val >> 8) & 0xFF) as u8)?;
try!(self.write(((val >> 0) & 0xFF) as u8)); self.write(((val >> 0) & 0xFF) as u8)?;
Ok(()) Ok(())
} }
@ -55,22 +55,19 @@ We'll also need a function for writing query names in labeled form:
```rust ```rust
fn write_qname(&mut self, qname: &str) -> Result<()> { fn write_qname(&mut self, qname: &str) -> Result<()> {
for label in qname.split('.') {
let split_str = qname.split('.').collect::<Vec<&str>>();
for label in split_str {
let len = label.len(); let len = label.len();
if len > 0x34 { if len > 0x34 {
return Err(Error::new(ErrorKind::InvalidInput, "Single label exceeds 63 characters of length")); return Err("Single label exceeds 63 characters of length".into());
} }
try!(self.write_u8(len as u8)); self.write_u8(len as u8)?;
for b in label.as_bytes() { for b in label.as_bytes() {
try!(self.write_u8(*b)); self.write_u8(*b)?;
} }
} }
try!(self.write_u8(0)); self.write_u8(0)?;
Ok(()) Ok(())
} }
@ -89,24 +86,28 @@ impl DnsHeader {
- snip - - snip -
pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> { pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> {
try!(buffer.write_u16(self.id)); buffer.write_u16(self.id)?;
try!(buffer.write_u8( ((self.recursion_desired as u8)) | buffer.write_u8(
((self.truncated_message as u8) << 1) | (self.recursion_desired as u8)
((self.authoritative_answer as u8) << 2) | | ((self.truncated_message as u8) << 1)
(self.opcode << 3) | | ((self.authoritative_answer as u8) << 2)
((self.response as u8) << 7) as u8) ); | (self.opcode << 3)
| ((self.response as u8) << 7) as u8,
)?;
try!(buffer.write_u8( (self.rescode.clone() as u8) | buffer.write_u8(
((self.checking_disabled as u8) << 4) | (self.rescode.clone() as u8)
((self.authed_data as u8) << 5) | | ((self.checking_disabled as u8) << 4)
((self.z as u8) << 6) | | ((self.authed_data as u8) << 5)
((self.recursion_available as u8) << 7) )); | ((self.z as u8) << 6)
| ((self.recursion_available as u8) << 7),
)?;
try!(buffer.write_u16(self.questions)); buffer.write_u16(self.questions)?;
try!(buffer.write_u16(self.answers)); buffer.write_u16(self.answers)?;
try!(buffer.write_u16(self.authoritative_entries)); buffer.write_u16(self.authoritative_entries)?;
try!(buffer.write_u16(self.resource_entries)); buffer.write_u16(self.resource_entries)?;
Ok(()) Ok(())
} }
@ -124,12 +125,11 @@ impl DnsQuestion {
- snip - - snip -
pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> { pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<()> {
buffer.write_qname(&self.name)?;
try!(buffer.write_qname(&self.name));
let typenum = self.qtype.to_num(); let typenum = self.qtype.to_num();
try!(buffer.write_u16(typenum)); buffer.write_u16(typenum)?;
try!(buffer.write_u16(1)); buffer.write_u16(1)?;
Ok(()) Ok(())
} }
@ -148,23 +148,26 @@ impl DnsRecord {
- snip - - snip -
pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<usize> { pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<usize> {
let start_pos = buffer.pos(); let start_pos = buffer.pos();
match *self { match *self {
DnsRecord::A { ref domain, ref addr, ttl } => { DnsRecord::A {
try!(buffer.write_qname(domain)); ref domain,
try!(buffer.write_u16(QueryType::A.to_num())); ref addr,
try!(buffer.write_u16(1)); ttl,
try!(buffer.write_u32(ttl)); } => {
try!(buffer.write_u16(4)); buffer.write_qname(domain)?;
buffer.write_u16(QueryType::A.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
buffer.write_u16(4)?;
let octets = addr.octets(); let octets = addr.octets();
try!(buffer.write_u8(octets[0])); buffer.write_u8(octets[0])?;
try!(buffer.write_u8(octets[1])); buffer.write_u8(octets[1])?;
try!(buffer.write_u8(octets[2])); buffer.write_u8(octets[2])?;
try!(buffer.write_u8(octets[3])); buffer.write_u8(octets[3])?;
}, }
DnsRecord::UNKNOWN { .. } => { DnsRecord::UNKNOWN { .. } => {
println!("Skipping record: {:?}", self); println!("Skipping record: {:?}", self);
} }
@ -185,26 +188,25 @@ impl DnsPacket {
- snip - - snip -
pub fn write(&mut self, buffer: &mut BytePacketBuffer) -> Result<()> pub fn write(&mut self, buffer: &mut BytePacketBuffer) -> Result<()> {
{
self.header.questions = self.questions.len() as u16; self.header.questions = self.questions.len() as u16;
self.header.answers = self.answers.len() as u16; self.header.answers = self.answers.len() as u16;
self.header.authoritative_entries = self.authorities.len() as u16; self.header.authoritative_entries = self.authorities.len() as u16;
self.header.resource_entries = self.resources.len() as u16; self.header.resource_entries = self.resources.len() as u16;
try!(self.header.write(buffer)); self.header.write(buffer)?;
for question in &self.questions { for question in &self.questions {
try!(question.write(buffer)); question.write(buffer)?;
} }
for rec in &self.answers { for rec in &self.answers {
try!(rec.write(buffer)); rec.write(buffer)?;
} }
for rec in &self.authorities { for rec in &self.authorities {
try!(rec.write(buffer)); rec.write(buffer)?;
} }
for rec in &self.resources { for rec in &self.resources {
try!(rec.write(buffer)); rec.write(buffer)?;
} }
Ok(()) Ok(())
@ -219,7 +221,7 @@ We're ready to implement our stub resolver. Rust includes a convenient
`UDPSocket` which does most of the work. `UDPSocket` which does most of the work.
```rust ```rust
fn main() { fn main() -> Result<()> {
// Perform an A query for google.com // Perform an A query for google.com
let qname = "google.com"; let qname = "google.com";
let qtype = QueryType::A; let qtype = QueryType::A;
@ -228,7 +230,7 @@ fn main() {
let server = ("8.8.8.8", 53); let server = ("8.8.8.8", 53);
// Bind a UDP socket to an arbitrary port // Bind a UDP socket to an arbitrary port
let socket = UdpSocket::bind(("0.0.0.0", 43210)).unwrap(); let socket = UdpSocket::bind(("0.0.0.0", 43210))?;
// Build our query packet. It's important that we remember to set the // Build our query packet. It's important that we remember to set the
// `recursion_desired` flag. As noted earlier, the packet id is arbitrary. // `recursion_desired` flag. As noted earlier, the packet id is arbitrary.
@ -237,37 +239,41 @@ fn main() {
packet.header.id = 6666; packet.header.id = 6666;
packet.header.questions = 1; packet.header.questions = 1;
packet.header.recursion_desired = true; packet.header.recursion_desired = true;
packet.questions.push(DnsQuestion::new(qname.to_string(), qtype)); packet
.questions
.push(DnsQuestion::new(qname.to_string(), qtype));
// Use our new write method to write the packet to a buffer... // Use our new write method to write the packet to a buffer...
let mut req_buffer = BytePacketBuffer::new(); let mut req_buffer = BytePacketBuffer::new();
packet.write(&mut req_buffer).unwrap(); packet.write(&mut req_buffer)?;
// ...and send it off to the server using our socket: // ...and send it off to the server using our socket:
socket.send_to(&req_buffer.buf[0..req_buffer.pos], server).unwrap(); socket.send_to(&req_buffer.buf[0..req_buffer.pos], server)?;
// To prepare for receiving the response, we'll create a new `BytePacketBuffer`, // To prepare for receiving the response, we'll create a new `BytePacketBuffer`,
// and ask the socket to write the response directly into our buffer. // and ask the socket to write the response directly into our buffer.
let mut res_buffer = BytePacketBuffer::new(); let mut res_buffer = BytePacketBuffer::new();
socket.recv_from(&mut res_buffer.buf).unwrap(); socket.recv_from(&mut res_buffer.buf)?;
// As per the previous section, `DnsPacket::from_buffer()` is then used to // As per the previous section, `DnsPacket::from_buffer()` is then used to
// actually parse the packet after which we can print the response. // actually parse the packet after which we can print the response.
let res_packet = DnsPacket::from_buffer(&mut res_buffer).unwrap(); let res_packet = DnsPacket::from_buffer(&mut res_buffer)?;
println!("{:?}", res_packet.header); println!("{:#?}", res_packet.header);
for q in res_packet.questions { for q in res_packet.questions {
println!("{:?}", q); println!("{:#?}", q);
} }
for rec in res_packet.answers { for rec in res_packet.answers {
println!("{:?}", rec); println!("{:#?}", rec);
} }
for rec in res_packet.authorities { for rec in res_packet.authorities {
println!("{:?}", rec); println!("{:#?}", rec);
} }
for rec in res_packet.resources { for rec in res_packet.resources {
println!("{:?}", rec); println!("{:#?}", rec);
} }
Ok(())
} }
``` ```

View File

@ -68,14 +68,14 @@ Let's go ahead and add them to our code! First we'll update our `QueryType`
enum: enum:
```rust ```rust
#[derive(PartialEq,Eq,Debug,Clone,Hash,Copy)] #[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)]
pub enum QueryType { pub enum QueryType {
UNKNOWN(u16), UNKNOWN(u16),
A, // 1 A, // 1
NS, // 2 NS, // 2
CNAME, // 5 CNAME, // 5
MX, // 15 MX, // 15
AAAA, // 28 AAAA, // 28
} }
``` ```
@ -101,7 +101,7 @@ impl QueryType {
5 => QueryType::CNAME, 5 => QueryType::CNAME,
15 => QueryType::MX, 15 => QueryType::MX,
28 => QueryType::AAAA, 28 => QueryType::AAAA,
_ => QueryType::UNKNOWN(num) _ => QueryType::UNKNOWN(num),
} }
} }
} }
@ -113,40 +113,40 @@ Now we need a way of holding the data for these records, so we'll make some
modifications to `DnsRecord`. modifications to `DnsRecord`.
```rust ```rust
#[derive(Debug,Clone,PartialEq,Eq,Hash,PartialOrd,Ord)] #[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)]
#[allow(dead_code)] #[allow(dead_code)]
pub enum DnsRecord { pub enum DnsRecord {
UNKNOWN { UNKNOWN {
domain: String, domain: String,
qtype: u16, qtype: u16,
data_len: u16, data_len: u16,
ttl: u32 ttl: u32,
}, // 0 }, // 0
A { A {
domain: String, domain: String,
addr: Ipv4Addr, addr: Ipv4Addr,
ttl: u32 ttl: u32,
}, // 1 }, // 1
NS { NS {
domain: String, domain: String,
host: String, host: String,
ttl: u32 ttl: u32,
}, // 2 }, // 2
CNAME { CNAME {
domain: String, domain: String,
host: String, host: String,
ttl: u32 ttl: u32,
}, // 5 }, // 5
MX { MX {
domain: String, domain: String,
priority: u16, priority: u16,
host: String, host: String,
ttl: u32 ttl: u32,
}, // 15 }, // 15
AAAA { AAAA {
domain: String, domain: String,
addr: Ipv6Addr, addr: Ipv6Addr,
ttl: u32 ttl: u32,
}, // 28 }, // 28
} }
``` ```
@ -156,106 +156,101 @@ and reading records. Starting with read, we amend it with additional code for
each record type. First off, we've got the common preamble: each record type. First off, we've got the common preamble:
```rust ```rust
pub fn read(buffer: &mut BytePacketBuffer) -> Result<DnsRecord> { impl DnsRecord {
let mut domain = String::new(); pub fn read(buffer: &mut BytePacketBuffer) -> Result<DnsRecord> {
try!(buffer.read_qname(&mut domain)); let mut domain = String::new();
buffer.read_qname(&mut domain)?;
let qtype_num = try!(buffer.read_u16()); let qtype_num = buffer.read_u16()?;
let qtype = QueryType::from_num(qtype_num); let qtype = QueryType::from_num(qtype_num);
let _ = try!(buffer.read_u16()); let _ = buffer.read_u16()?;
let ttl = try!(buffer.read_u32()); let ttl = buffer.read_u32()?;
let data_len = try!(buffer.read_u16()); let data_len = buffer.read_u16()?;
match qtype { match qtype {
QueryType::A => {
let raw_addr = buffer.read_u32()?;
let addr = Ipv4Addr::new(
((raw_addr >> 24) & 0xFF) as u8,
((raw_addr >> 16) & 0xFF) as u8,
((raw_addr >> 8) & 0xFF) as u8,
((raw_addr >> 0) & 0xFF) as u8,
);
// Handle each record type separately, starting with the A record Ok(DnsRecord::A {
// type which remains the same as before. domain: domain,
QueryType::A => { addr: addr,
let raw_addr = try!(buffer.read_u32()); ttl: ttl,
let addr = Ipv4Addr::new(((raw_addr >> 24) & 0xFF) as u8, })
((raw_addr >> 16) & 0xFF) as u8, }
((raw_addr >> 8) & 0xFF) as u8, QueryType::AAAA => {
((raw_addr >> 0) & 0xFF) as u8); let raw_addr1 = buffer.read_u32()?;
let raw_addr2 = buffer.read_u32()?;
let raw_addr3 = buffer.read_u32()?;
let raw_addr4 = buffer.read_u32()?;
let addr = Ipv6Addr::new(
((raw_addr1 >> 16) & 0xFFFF) as u16,
((raw_addr1 >> 0) & 0xFFFF) as u16,
((raw_addr2 >> 16) & 0xFFFF) as u16,
((raw_addr2 >> 0) & 0xFFFF) as u16,
((raw_addr3 >> 16) & 0xFFFF) as u16,
((raw_addr3 >> 0) & 0xFFFF) as u16,
((raw_addr4 >> 16) & 0xFFFF) as u16,
((raw_addr4 >> 0) & 0xFFFF) as u16,
);
Ok(DnsRecord::A { Ok(DnsRecord::AAAA {
domain: domain, domain: domain,
addr: addr, addr: addr,
ttl: ttl ttl: ttl,
}) })
}, }
QueryType::NS => {
let mut ns = String::new();
buffer.read_qname(&mut ns)?;
// The AAAA record type follows the same logic, but with more numbers to keep Ok(DnsRecord::NS {
// track off. domain: domain,
QueryType::AAAA => { host: ns,
let raw_addr1 = try!(buffer.read_u32()); ttl: ttl,
let raw_addr2 = try!(buffer.read_u32()); })
let raw_addr3 = try!(buffer.read_u32()); }
let raw_addr4 = try!(buffer.read_u32()); QueryType::CNAME => {
let addr = Ipv6Addr::new(((raw_addr1 >> 16) & 0xFFFF) as u16, let mut cname = String::new();
((raw_addr1 >> 0) & 0xFFFF) as u16, buffer.read_qname(&mut cname)?;
((raw_addr2 >> 16) & 0xFFFF) as u16,
((raw_addr2 >> 0) & 0xFFFF) as u16,
((raw_addr3 >> 16) & 0xFFFF) as u16,
((raw_addr3 >> 0) & 0xFFFF) as u16,
((raw_addr4 >> 16) & 0xFFFF) as u16,
((raw_addr4 >> 0) & 0xFFFF) as u16);
Ok(DnsRecord::AAAA { Ok(DnsRecord::CNAME {
domain: domain, domain: domain,
addr: addr, host: cname,
ttl: ttl ttl: ttl,
}) })
}, }
QueryType::MX => {
let priority = buffer.read_u16()?;
let mut mx = String::new();
buffer.read_qname(&mut mx)?;
// NS and CNAME both have the same structure. Ok(DnsRecord::MX {
QueryType::NS => { domain: domain,
let mut ns = String::new(); priority: priority,
try!(buffer.read_qname(&mut ns)); host: mx,
ttl: ttl,
})
}
QueryType::UNKNOWN(_) => {
buffer.step(data_len as usize)?;
Ok(DnsRecord::NS { Ok(DnsRecord::UNKNOWN {
domain: domain, domain: domain,
host: ns, qtype: qtype_num,
ttl: ttl data_len: data_len,
}) ttl: ttl,
}, })
}
QueryType::CNAME => {
let mut cname = String::new();
try!(buffer.read_qname(&mut cname));
Ok(DnsRecord::CNAME {
domain: domain,
host: cname,
ttl: ttl
})
},
// MX is almost like the previous two, but with one extra field for priority.
QueryType::MX => {
let priority = try!(buffer.read_u16());
let mut mx = String::new();
try!(buffer.read_qname(&mut mx));
Ok(DnsRecord::MX {
domain: domain,
priority: priority,
host: mx,
ttl: ttl
})
},
// And we end with some code for handling unknown record types, as before.
QueryType::UNKNOWN(_) => {
try!(buffer.step(data_len as usize));
Ok(DnsRecord::UNKNOWN {
domain: domain,
qtype: qtype_num,
data_len: data_len,
ttl: ttl
})
} }
} }
- snip -
} }
``` ```
@ -280,8 +275,8 @@ impl BytePacketBuffer {
} }
fn set_u16(&mut self, pos: usize, val: u16) -> Result<()> { fn set_u16(&mut self, pos: usize, val: u16) -> Result<()> {
try!(self.set(pos,(val >> 8) as u8)); self.set(pos, (val >> 8) as u8)?;
try!(self.set(pos+1,(val & 0xFF) as u8)); self.set(pos + 1, (val & 0xFF) as u8)?;
Ok(()) Ok(())
} }
@ -289,89 +284,119 @@ impl BytePacketBuffer {
} }
``` ```
When writing the labels of a record, we don't know ahead of time the number of
bytes needed, since we might end up using jumps to compress the size. We'll
solve this by writing a zero size and then going back to fill in the size
needed.
### Extending DnsRecord for writing new record types ### Extending DnsRecord for writing new record types
Now we can amend `DnsRecord::write`. Here's our new function: Now we can amend `DnsRecord::write`. Here's our new function:
```rust ```rust
pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<usize> { impl DnsRecord {
let start_pos = buffer.pos(); - snip -
match *self { pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<usize> {
DnsRecord::A { ref domain, ref addr, ttl } => { let start_pos = buffer.pos();
try!(buffer.write_qname(domain));
try!(buffer.write_u16(QueryType::A.to_num()));
try!(buffer.write_u16(1));
try!(buffer.write_u32(ttl));
try!(buffer.write_u16(4));
let octets = addr.octets(); match *self {
try!(buffer.write_u8(octets[0])); DnsRecord::A {
try!(buffer.write_u8(octets[1])); ref domain,
try!(buffer.write_u8(octets[2])); ref addr,
try!(buffer.write_u8(octets[3])); ttl,
}, } => {
DnsRecord::NS { ref domain, ref host, ttl } => { buffer.write_qname(domain)?;
try!(buffer.write_qname(domain)); buffer.write_u16(QueryType::A.to_num())?;
try!(buffer.write_u16(QueryType::NS.to_num())); buffer.write_u16(1)?;
try!(buffer.write_u16(1)); buffer.write_u32(ttl)?;
try!(buffer.write_u32(ttl)); buffer.write_u16(4)?;
let pos = buffer.pos(); let octets = addr.octets();
try!(buffer.write_u16(0)); buffer.write_u8(octets[0])?;
buffer.write_u8(octets[1])?;
try!(buffer.write_qname(host)); buffer.write_u8(octets[2])?;
buffer.write_u8(octets[3])?;
let size = buffer.pos() - (pos + 2);
try!(buffer.set_u16(pos, size as u16));
},
DnsRecord::CNAME { ref domain, ref host, ttl } => {
try!(buffer.write_qname(domain));
try!(buffer.write_u16(QueryType::CNAME.to_num()));
try!(buffer.write_u16(1));
try!(buffer.write_u32(ttl));
let pos = buffer.pos();
try!(buffer.write_u16(0));
try!(buffer.write_qname(host));
let size = buffer.pos() - (pos + 2);
try!(buffer.set_u16(pos, size as u16));
},
DnsRecord::MX { ref domain, priority, ref host, ttl } => {
try!(buffer.write_qname(domain));
try!(buffer.write_u16(QueryType::MX.to_num()));
try!(buffer.write_u16(1));
try!(buffer.write_u32(ttl));
let pos = buffer.pos();
try!(buffer.write_u16(0));
try!(buffer.write_u16(priority));
try!(buffer.write_qname(host));
let size = buffer.pos() - (pos + 2);
try!(buffer.set_u16(pos, size as u16));
},
DnsRecord::AAAA { ref domain, ref addr, ttl } => {
try!(buffer.write_qname(domain));
try!(buffer.write_u16(QueryType::AAAA.to_num()));
try!(buffer.write_u16(1));
try!(buffer.write_u32(ttl));
try!(buffer.write_u16(16));
for octet in &addr.segments() {
try!(buffer.write_u16(*octet));
} }
}, DnsRecord::NS {
DnsRecord::UNKNOWN { .. } => { ref domain,
println!("Skipping record: {:?}", self); ref host,
} ttl,
} } => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::NS.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
Ok(buffer.pos() - start_pos) let pos = buffer.pos();
buffer.write_u16(0)?;
buffer.write_qname(host)?;
let size = buffer.pos() - (pos + 2);
buffer.set_u16(pos, size as u16)?;
}
DnsRecord::CNAME {
ref domain,
ref host,
ttl,
} => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::CNAME.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
let pos = buffer.pos();
buffer.write_u16(0)?;
buffer.write_qname(host)?;
let size = buffer.pos() - (pos + 2);
buffer.set_u16(pos, size as u16)?;
}
DnsRecord::MX {
ref domain,
priority,
ref host,
ttl,
} => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::MX.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
let pos = buffer.pos();
buffer.write_u16(0)?;
buffer.write_u16(priority)?;
buffer.write_qname(host)?;
let size = buffer.pos() - (pos + 2);
buffer.set_u16(pos, size as u16)?;
}
DnsRecord::AAAA {
ref domain,
ref addr,
ttl,
} => {
buffer.write_qname(domain)?;
buffer.write_u16(QueryType::AAAA.to_num())?;
buffer.write_u16(1)?;
buffer.write_u32(ttl)?;
buffer.write_u16(16)?;
for octet in &addr.segments() {
buffer.write_u16(*octet)?;
}
}
DnsRecord::UNKNOWN { .. } => {
println!("Skipping record: {:?}", self);
}
}
Ok(buffer.pos() - start_pos)
}
} }
``` ```

View File

@ -141,26 +141,27 @@ work, it's a rather quick effort!
We'll start out by doing some quick refactoring, moving our lookup code into We'll start out by doing some quick refactoring, moving our lookup code into
a separate function. This is for the most part the same code as we had in our a separate function. This is for the most part the same code as we had in our
`main` function in the previous chapter, with the only change being that we `main` function in the previous chapter.
handle errors gracefully using `try!`.
```rust ```rust
fn lookup(qname: &str, qtype: QueryType, server: (&str, u16)) -> Result<DnsPacket> { fn lookup(qname: &str, qtype: QueryType, server: (&str, u16)) -> Result<DnsPacket> {
let socket = try!(UdpSocket::bind(("0.0.0.0", 43210))); let socket = UdpSocket::bind(("0.0.0.0", 43210))?;
let mut packet = DnsPacket::new(); let mut packet = DnsPacket::new();
packet.header.id = 6666; packet.header.id = 6666;
packet.header.questions = 1; packet.header.questions = 1;
packet.header.recursion_desired = true; packet.header.recursion_desired = true;
packet.questions.push(DnsQuestion::new(qname.to_string(), qtype)); packet
.questions
.push(DnsQuestion::new(qname.to_string(), qtype));
let mut req_buffer = BytePacketBuffer::new(); let mut req_buffer = BytePacketBuffer::new();
packet.write(&mut req_buffer).unwrap(); packet.write(&mut req_buffer)?;
try!(socket.send_to(&req_buffer.buf[0..req_buffer.pos], server)); socket.send_to(&req_buffer.buf[0..req_buffer.pos], server)?;
let mut res_buffer = BytePacketBuffer::new(); let mut res_buffer = BytePacketBuffer::new();
socket.recv_from(&mut res_buffer.buf).unwrap(); socket.recv_from(&mut res_buffer.buf)?;
DnsPacket::from_buffer(&mut res_buffer) DnsPacket::from_buffer(&mut res_buffer)
} }
@ -171,12 +172,12 @@ fn lookup(qname: &str, qtype: QueryType, server: (&str, u16)) -> Result<DnsPacke
Now we'll write our server code. First, we need get some things in order. Now we'll write our server code. First, we need get some things in order.
```rust ```rust
fn main() { fn main() -> Result<()> {
// Forward queries to Google's public DNS // Forward queries to Google's public DNS
let server = ("8.8.8.8", 53); let server = ("8.8.8.8", 53);
// Bind an UDP socket on port 2053 // Bind an UDP socket on port 2053
let socket = UdpSocket::bind(("0.0.0.0", 2053)).unwrap(); let socket = UdpSocket::bind(("0.0.0.0", 2053))?;
// For now, queries are handled sequentially, so an infinite loop for servicing // For now, queries are handled sequentially, so an infinite loop for servicing
// requests is initiated. // requests is initiated.
@ -224,7 +225,6 @@ fn main() {
if request.questions.is_empty() { if request.questions.is_empty() {
packet.header.rescode = ResultCode::FORMERR; packet.header.rescode = ResultCode::FORMERR;
} }
// Usually a question will be present, though. // Usually a question will be present, though.
else { else {
let question = &request.questions[0]; let question = &request.questions[0];
@ -254,37 +254,36 @@ fn main() {
} else { } else {
packet.header.rescode = ResultCode::SERVFAIL; packet.header.rescode = ResultCode::SERVFAIL;
} }
// The only thing remaining is to encode our response and send it off!
let mut res_buffer = BytePacketBuffer::new();
match packet.write(&mut res_buffer) {
Ok(_) => {},
Err(e) => {
println!("Failed to encode UDP response packet: {:?}", e);
continue;
}
};
let len = res_buffer.pos();
let data = match res_buffer.get_range(0, len) {
Ok(x) => x,
Err(e) => {
println!("Failed to retrieve response buffer: {:?}", e);
continue;
}
};
match socket.send_to(data, src) {
Ok(_) => {},
Err(e) => {
println!("Failed to send response buffer: {:?}", e);
continue;
}
};
} }
} // End of request loop
} // End of main // The only thing remaining is to encode our response and send it off!
let mut res_buffer = BytePacketBuffer::new();
match packet.write(&mut res_buffer) {
Ok(_) => {}
Err(e) => {
println!("Failed to encode UDP response packet: {:?}", e);
continue;
}
};
let len = res_buffer.pos();
let data = match res_buffer.get_range(0, len) {
Ok(x) => x,
Err(e) => {
println!("Failed to retrieve response buffer: {:?}", e);
continue;
}
};
match socket.send_to(data, src) {
Ok(_) => {}
Err(e) => {
println!("Failed to send response buffer: {:?}", e);
continue;
}
};
}
}
``` ```
The match idiom for error handling is used again and again here, since we want to avoid The match idiom for error handling is used again and again here, since we want to avoid

View File

@ -168,89 +168,68 @@ impl DnsPacket {
- snip - - snip -
// It's useful to be able to pick a random A record from a packet. When we /// It's useful to be able to pick a random A record from a packet. When we
// get multiple IP's for a single name, it doesn't matter which one we /// get multiple IP's for a single name, it doesn't matter which one we
// choose, so in those cases we can now pick one at random. /// choose, so in those cases we can now pick one at random.
pub fn get_random_a(&self) -> Option<String> { pub fn get_random_a(&self) -> Option<String> {
if !self.answers.is_empty() { self.answers
let idx = random::<usize>() % self.answers.len(); .iter()
let a_record = &self.answers[idx]; .filter_map(|record| match record {
if let DnsRecord::A{ ref addr, .. } = *a_record { DnsRecord::A { ref addr, .. } => Some(addr.to_string()),
return Some(addr.to_string()); _ => None,
} })
} .next()
None
} }
// We'll use the fact that name servers often bundle the corresponding /// A helper function which returns an iterator over all name servers in
// A records when replying to an NS query to implement a function that returns /// the authorities section, represented as (domain, host) tuples
// the actual IP for an NS record if possible. fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator<Item=(&'a str, &'a str)> {
self.authorities.iter()
// In practice, these are always NS records in well formed packages.
// Convert the NS records to a tuple which has only the data we need
// to make it easy to work with.
.filter_map(|record| match record {
DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())),
_ => None,
})
// Discard servers which aren't authoritative to our query
.filter(move |(domain, _)| qname.ends_with(*domain))
}
/// When there is a NS record in the authorities section, there may also
/// be a matching A record in the additional section. This saves us
/// from doing a separate query to resolve the IP of the name server.
pub fn get_resolved_ns(&self, qname: &str) -> Option<String> { pub fn get_resolved_ns(&self, qname: &str) -> Option<String> {
// Get an iterator over the nameservers in the authorities section
self.get_ns(qname)
// Now we need to look for a matching A record in the additional
// section. Since we just want the first valid record, we can just
// build a stream of matching records.
.flat_map(|(_, host)| {
self.resources.iter()
// Filter for A records where the domain match the host
// of the NS record that we are currently processing
.filter_map(move |record| match record {
DnsRecord::A { domain, addr, .. } if domain == host => Some(addr),
_ => None,
})
})
.map(|addr| addr.to_string())
// Finally, pick the first valid entry
.next()
}
// First, we scan the list of NS records in the authorities section: /// However, not all name servers are as that nice. In certain cases there won't
let mut new_authorities = Vec::new(); /// be any A records in the additional section, and we'll have to perform *another*
for auth in &self.authorities { /// lookup in the midst of our first. For this, we introduce a method for
if let DnsRecord::NS { ref domain, ref host, .. } = *auth { returning the hostname of an appropriate name server.
if !qname.ends_with(domain) {
continue;
}
// Once we've found an NS record, we scan the resources record for a matching
// A record...
for rsrc in &self.resources {
if let DnsRecord::A{ ref domain, ref addr, ttl } = *rsrc {
if domain != host {
continue;
}
let rec = DnsRecord::A {
domain: host.clone(),
addr: *addr,
ttl: ttl
};
// ...and push any matches to a list.
new_authorities.push(rec);
}
}
}
}
// If there are any matches, we pick the first one.
if !new_authorities.is_empty() {
if let DnsRecord::A { addr, .. } = new_authorities[0] {
return Some(addr.to_string());
}
}
None
} // End of get_resolved_ns
// However, not all name servers are as that nice. In certain cases there won't
// be any A records in the additional section, and we'll have to perform *another*
// lookup in the midst. For this, we introduce a method for returning the host
// name of an appropriate name server.
pub fn get_unresolved_ns(&self, qname: &str) -> Option<String> { pub fn get_unresolved_ns(&self, qname: &str) -> Option<String> {
// Get an iterator over the nameservers in the authorities section
let mut new_authorities = Vec::new(); self.get_ns(qname)
for auth in &self.authorities { .map(|(_, host)| host.to_string())
if let DnsRecord::NS { ref domain, ref host, .. } = *auth { // Finally, pick the first valid entry
if !qname.ends_with(domain) { .next()
continue; }
}
new_authorities.push(host);
}
}
if !new_authorities.is_empty() {
let idx = random::<usize>() % new_authorities.len();
return Some(new_authorities[idx].clone());
}
None
} // End of get_unresolved_ns
} // End of DnsPacket } // End of DnsPacket
``` ```
@ -273,12 +252,10 @@ fn recursive_lookup(qname: &str, qtype: QueryType) -> Result<DnsPacket> {
let ns_copy = ns.clone(); let ns_copy = ns.clone();
let server = (ns_copy.as_str(), 53); let server = (ns_copy.as_str(), 53);
let response = try!(lookup(qname, qtype.clone(), server)); let response = lookup(qname, qtype.clone(), server)?;
// If there are entries in the answer section, and no errors, we are done! // If there are entries in the answer section, and no errors, we are done!
if !response.answers.is_empty() && if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
response.header.rescode == ResultCode::NOERROR {
return Ok(response.clone()); return Ok(response.clone());
} }
@ -301,23 +278,23 @@ fn recursive_lookup(qname: &str, qtype: QueryType) -> Result<DnsPacket> {
// we'll go with what the last server told us. // we'll go with what the last server told us.
let new_ns_name = match response.get_unresolved_ns(qname) { let new_ns_name = match response.get_unresolved_ns(qname) {
Some(x) => x, Some(x) => x,
None => return Ok(response.clone()) None => return Ok(response.clone()),
}; };
// Here we go down the rabbit hole by starting _another_ lookup sequence in the // Here we go down the rabbit hole by starting _another_ lookup sequence in the
// midst of our current one. Hopefully, this will give us the IP of an appropriate // midst of our current one. Hopefully, this will give us the IP of an appropriate
// name server. // name server.
let recursive_response = try!(recursive_lookup(&new_ns_name, QueryType::A)); let recursive_response = recursive_lookup(&new_ns_name, QueryType::A)?;
// Finally, we pick a random ip from the result, and restart the loop. If no such // Finally, we pick a random ip from the result, and restart the loop. If no such
// record is available, we again return the last result we got. // record is available, we again return the last result we got.
if let Some(new_ns) = recursive_response.get_random_a() { if let Some(new_ns) = recursive_response.get_random_a() {
ns = new_ns.clone(); ns = new_ns.clone();
} else { } else {
return Ok(response.clone()) return Ok(response.clone());
} }
} }
} // End of recursive_lookup }
``` ```
### Trying out recursive lookup ### Trying out recursive lookup
@ -326,7 +303,7 @@ The only thing remaining is to change our main function to use
`recursive_lookup`: `recursive_lookup`:
```rust ```rust
fn main() { fn main() -> Result<()> {
- snip - - snip -

View File

@ -11,6 +11,8 @@ pub struct BytePacketBuffer {
} }
impl BytePacketBuffer { impl BytePacketBuffer {
/// This gives us a fresh buffer for holding the packet contents, and a
/// field for keeping track of where we are.
pub fn new() -> BytePacketBuffer { pub fn new() -> BytePacketBuffer {
BytePacketBuffer { BytePacketBuffer {
buf: [0; 512], buf: [0; 512],
@ -18,22 +20,26 @@ impl BytePacketBuffer {
} }
} }
/// Current position within buffer
fn pos(&self) -> usize { fn pos(&self) -> usize {
self.pos self.pos
} }
/// Step the buffer position forward a specific number of steps
fn step(&mut self, steps: usize) -> Result<()> { fn step(&mut self, steps: usize) -> Result<()> {
self.pos += steps; self.pos += steps;
Ok(()) Ok(())
} }
/// Change the buffer position
fn seek(&mut self, pos: usize) -> Result<()> { fn seek(&mut self, pos: usize) -> Result<()> {
self.pos = pos; self.pos = pos;
Ok(()) Ok(())
} }
/// Read a single byte and move the position one step forward
fn read(&mut self) -> Result<u8> { fn read(&mut self) -> Result<u8> {
if self.pos >= 512 { if self.pos >= 512 {
return Err("End of buffer".into()); return Err("End of buffer".into());
@ -44,6 +50,7 @@ impl BytePacketBuffer {
Ok(res) Ok(res)
} }
/// Get a single byte, without changing the buffer position
fn get(&mut self, pos: usize) -> Result<u8> { fn get(&mut self, pos: usize) -> Result<u8> {
if pos >= 512 { if pos >= 512 {
return Err("End of buffer".into()); return Err("End of buffer".into());
@ -51,6 +58,7 @@ impl BytePacketBuffer {
Ok(self.buf[pos]) Ok(self.buf[pos])
} }
/// Get a range of bytes
fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> { fn get_range(&mut self, start: usize, len: usize) -> Result<&[u8]> {
if start + len >= 512 { if start + len >= 512 {
return Err("End of buffer".into()); return Err("End of buffer".into());
@ -58,12 +66,14 @@ impl BytePacketBuffer {
Ok(&self.buf[start..start + len as usize]) Ok(&self.buf[start..start + len as usize])
} }
/// Read two bytes, stepping two steps forward
fn read_u16(&mut self) -> Result<u16> { fn read_u16(&mut self) -> Result<u16> {
let res = ((self.read()? as u16) << 8) | (self.read()? as u16); let res = ((self.read()? as u16) << 8) | (self.read()? as u16);
Ok(res) Ok(res)
} }
/// Read four bytes, stepping four steps forward
fn read_u32(&mut self) -> Result<u32> { fn read_u32(&mut self) -> Result<u32> {
let res = ((self.read()? as u32) << 24) let res = ((self.read()? as u32) << 24)
| ((self.read()? as u32) << 16) | ((self.read()? as u32) << 16)
@ -73,13 +83,28 @@ impl BytePacketBuffer {
Ok(res) Ok(res)
} }
/// Read a qname
///
/// The tricky part: Reading domain names, taking labels into consideration.
/// Will take something like [3]www[6]google[3]com[0] and append
/// www.google.com to outstr.
fn read_qname(&mut self, outstr: &mut String) -> Result<()> { fn read_qname(&mut self, outstr: &mut String) -> Result<()> {
// Since we might encounter jumps, we'll keep track of our position
// locally as opposed to using the position within the struct. This
// allows us to move the shared position to a point past our current
// qname, while keeping track of our progress on the current qname
// using this variable.
let mut pos = self.pos(); let mut pos = self.pos();
let mut jumped = false;
let mut delim = ""; // track whether or not we've jumped
let mut jumped = false;
let max_jumps = 5; let max_jumps = 5;
let mut jumps_performed = 0; let mut jumps_performed = 0;
// Our delimiter which we append for each label. Since we don't want a
// dot at the beginning of the domain name we'll leave it empty for now
// and set it to "." at the end of the first iteration.
let mut delim = "";
loop { loop {
// Dns Packets are untrusted data, so we need to be paranoid. Someone // Dns Packets are untrusted data, so we need to be paranoid. Someone
// can craft a packet with a cycle in the jump instructions. This guards // can craft a packet with a cycle in the jump instructions. This guards
@ -88,42 +113,56 @@ impl BytePacketBuffer {
return Err(format!("Limit of {} jumps exceeded", max_jumps).into()); return Err(format!("Limit of {} jumps exceeded", max_jumps).into());
} }
// At this point, we're always at the beginning of a label. Recall
// that labels start with a length byte.
let len = self.get(pos)?; let len = self.get(pos)?;
// A two byte sequence, where the two highest bits of the first byte is // If len has the two most significant bit are set, it represents a
// set, represents a offset relative to the start of the buffer. We // jump to some other offset in the packet:
// handle this by jumping to the offset, setting a flag to indicate
// that we shouldn't update the shared buffer position once done.
if (len & 0xC0) == 0xC0 { if (len & 0xC0) == 0xC0 {
// When a jump is performed, we only modify the shared buffer // Update the buffer position to a point past the current
// position once, and avoid making the change later on. // label. We don't need to touch it any further.
if !jumped { if !jumped {
self.seek(pos + 2)?; self.seek(pos + 2)?;
} }
// Read another byte, calculate offset and perform the jump by
// updating our local position variable
let b2 = self.get(pos + 1)? as u16; let b2 = self.get(pos + 1)? as u16;
let offset = (((len as u16) ^ 0xC0) << 8) | b2; let offset = (((len as u16) ^ 0xC0) << 8) | b2;
pos = offset as usize; pos = offset as usize;
// Indicate that a jump was performed.
jumped = true; jumped = true;
jumps_performed += 1; jumps_performed += 1;
continue; continue;
} }
// The base scenario, where we're reading a single label and
// appending it to the output:
else {
// Move a single byte forward to move past the length byte.
pos += 1;
pos += 1; // Domain names are terminated by an empty label of length 0,
// so if the length is zero we're done.
if len == 0 {
break;
}
// Names are terminated by an empty label of length 0 // Append the delimiter to our output buffer first.
if len == 0 { outstr.push_str(delim);
break;
// Extract the actual ASCII bytes for this label and append them
// to the output buffer.
let str_buffer = self.get_range(pos, len as usize)?;
outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase());
delim = ".";
// Move forward the full length of the label.
pos += len as usize;
} }
outstr.push_str(delim);
let str_buffer = self.get_range(pos, len as usize)?;
outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase());
delim = ".";
pos += len as usize;
} }
if !jumped { if !jumped {
@ -386,19 +425,19 @@ fn main() -> Result<()> {
f.read(&mut buffer.buf)?; f.read(&mut buffer.buf)?;
let packet = DnsPacket::from_buffer(&mut buffer)?; let packet = DnsPacket::from_buffer(&mut buffer)?;
println!("{:?}", packet.header); println!("{:#?}", packet.header);
for q in packet.questions { for q in packet.questions {
println!("{:?}", q); println!("{:#?}", q);
} }
for rec in packet.answers { for rec in packet.answers {
println!("{:?}", rec); println!("{:#?}", rec);
} }
for rec in packet.authorities { for rec in packet.authorities {
println!("{:?}", rec); println!("{:#?}", rec);
} }
for rec in packet.resources { for rec in packet.resources {
println!("{:?}", rec); println!("{:#?}", rec);
} }
Ok(()) Ok(())

View File

@ -164,9 +164,7 @@ impl BytePacketBuffer {
} }
fn write_qname(&mut self, qname: &str) -> Result<()> { fn write_qname(&mut self, qname: &str) -> Result<()> {
let split_str = qname.split('.').collect::<Vec<&str>>(); for label in qname.split('.') {
for label in split_str {
let len = label.len(); let len = label.len();
if len > 0x34 { if len > 0x34 {
return Err("Single label exceeds 63 characters of length".into()); return Err("Single label exceeds 63 characters of length".into());
@ -521,12 +519,18 @@ impl DnsPacket {
} }
fn main() -> Result<()> { fn main() -> Result<()> {
let qname = "www.yahoo.com"; // Perform an A query for google.com
let qname = "google.com";
let qtype = QueryType::A; let qtype = QueryType::A;
// Using googles public DNS server
let server = ("8.8.8.8", 53); let server = ("8.8.8.8", 53);
// Bind a UDP socket to an arbitrary port
let socket = UdpSocket::bind(("0.0.0.0", 43210))?; let socket = UdpSocket::bind(("0.0.0.0", 43210))?;
// Build our query packet. It's important that we remember to set the
// `recursion_desired` flag. As noted earlier, the packet id is arbitrary.
let mut packet = DnsPacket::new(); let mut packet = DnsPacket::new();
packet.header.id = 6666; packet.header.id = 6666;
@ -536,27 +540,34 @@ fn main() -> Result<()> {
.questions .questions
.push(DnsQuestion::new(qname.to_string(), qtype)); .push(DnsQuestion::new(qname.to_string(), qtype));
// Use our new write method to write the packet to a buffer...
let mut req_buffer = BytePacketBuffer::new(); let mut req_buffer = BytePacketBuffer::new();
packet.write(&mut req_buffer)?; packet.write(&mut req_buffer)?;
// ...and send it off to the server using our socket:
socket.send_to(&req_buffer.buf[0..req_buffer.pos], server)?; socket.send_to(&req_buffer.buf[0..req_buffer.pos], server)?;
// To prepare for receiving the response, we'll create a new `BytePacketBuffer`,
// and ask the socket to write the response directly into our buffer.
let mut res_buffer = BytePacketBuffer::new(); let mut res_buffer = BytePacketBuffer::new();
socket.recv_from(&mut res_buffer.buf)?; socket.recv_from(&mut res_buffer.buf)?;
// As per the previous section, `DnsPacket::from_buffer()` is then used to
// actually parse the packet after which we can print the response.
let res_packet = DnsPacket::from_buffer(&mut res_buffer)?; let res_packet = DnsPacket::from_buffer(&mut res_buffer)?;
println!("{:?}", res_packet.header); println!("{:#?}", res_packet.header);
for q in res_packet.questions { for q in res_packet.questions {
println!("{:?}", q); println!("{:#?}", q);
} }
for rec in res_packet.answers { for rec in res_packet.answers {
println!("{:?}", rec); println!("{:#?}", rec);
} }
for rec in res_packet.authorities { for rec in res_packet.authorities {
println!("{:?}", rec); println!("{:#?}", rec);
} }
for rec in res_packet.resources { for rec in res_packet.resources {
println!("{:?}", rec); println!("{:#?}", rec);
} }
Ok(()) Ok(())

View File

@ -164,9 +164,7 @@ impl BytePacketBuffer {
} }
fn write_qname(&mut self, qname: &str) -> Result<()> { fn write_qname(&mut self, qname: &str) -> Result<()> {
let split_str = qname.split('.').collect::<Vec<&str>>(); for label in qname.split('.') {
for label in split_str {
let len = label.len(); let len = label.len();
if len > 0x34 { if len > 0x34 {
return Err("Single label exceeds 63 characters of length".into()); return Err("Single label exceeds 63 characters of length".into());

View File

@ -164,9 +164,7 @@ impl BytePacketBuffer {
} }
fn write_qname(&mut self, qname: &str) -> Result<()> { fn write_qname(&mut self, qname: &str) -> Result<()> {
let split_str = qname.split('.').collect::<Vec<&str>>(); for label in qname.split('.') {
for label in split_str {
let len = label.len(); let len = label.len();
if len > 0x34 { if len > 0x34 {
return Err("Single label exceeds 63 characters of length".into()); return Err("Single label exceeds 63 characters of length".into());
@ -714,11 +712,17 @@ fn lookup(qname: &str, qtype: QueryType, server: (&str, u16)) -> Result<DnsPacke
} }
fn main() -> Result<()> { fn main() -> Result<()> {
// Forward queries to Google's public DNS
let server = ("8.8.8.8", 53); let server = ("8.8.8.8", 53);
// Bind an UDP socket on port 2053
let socket = UdpSocket::bind(("0.0.0.0", 2053))?; let socket = UdpSocket::bind(("0.0.0.0", 2053))?;
// For now, queries are handled sequentially, so an infinite loop for servicing
// requests is initiated.
loop { loop {
// With a socket ready, we can go ahead and read a packet. This will
// block until one is received.
let mut req_buffer = BytePacketBuffer::new(); let mut req_buffer = BytePacketBuffer::new();
let (_, src) = match socket.recv_from(&mut req_buffer.buf) { let (_, src) = match socket.recv_from(&mut req_buffer.buf) {
Ok(x) => x, Ok(x) => x,
@ -728,6 +732,16 @@ fn main() -> Result<()> {
} }
}; };
// Here we use match to safely unwrap the `Result`. If everything's as expected,
// the raw bytes are simply returned, and if not it'll abort by restarting the
// loop and waiting for the next request. The `recv_from` function will write the
// data into the provided buffer, and return the length of the data read as well
// as the source address. We're not interested in the length, but we need to keep
// track of the source in order to send our reply later on.
// Next, `DnsPacket::from_buffer` is used to parse the raw bytes into
// a `DnsPacket`. It uses the same error handling idiom as the previous statement.
let request = match DnsPacket::from_buffer(&mut req_buffer) { let request = match DnsPacket::from_buffer(&mut req_buffer) {
Ok(x) => x, Ok(x) => x,
Err(e) => { Err(e) => {
@ -736,18 +750,29 @@ fn main() -> Result<()> {
} }
}; };
// Create and initialize the response packet
let mut packet = DnsPacket::new(); let mut packet = DnsPacket::new();
packet.header.id = request.header.id; packet.header.id = request.header.id;
packet.header.recursion_desired = true; packet.header.recursion_desired = true;
packet.header.recursion_available = true; packet.header.recursion_available = true;
packet.header.response = true; packet.header.response = true;
// Being mindful of how unreliable input data from arbitrary senders can be, we
// need make sure that a question is actually present. If not, we return `FORMERR`
// to indicate that the sender made something wrong.
if request.questions.is_empty() { if request.questions.is_empty() {
packet.header.rescode = ResultCode::FORMERR; packet.header.rescode = ResultCode::FORMERR;
} else { }
// Usually a question will be present, though.
else {
let question = &request.questions[0]; let question = &request.questions[0];
println!("Received query: {:?}", question); println!("Received query: {:?}", question);
// Since all is set up and as expected, the query can be forwarded to the target
// server. There's always the possibility that the query will fail, in which case
// the `SERVFAIL` response code is set to indicate as much to the client. If
// rather everything goes as planned, the question and response records as copied
// into our response packet.
if let Ok(result) = lookup(&question.name, question.qtype, server) { if let Ok(result) = lookup(&question.name, question.qtype, server) {
packet.questions.push(question.clone()); packet.questions.push(question.clone());
packet.header.rescode = result.header.rescode; packet.header.rescode = result.header.rescode;
@ -769,6 +794,7 @@ fn main() -> Result<()> {
} }
} }
// The only thing remaining is to encode our response and send it off!
let mut res_buffer = BytePacketBuffer::new(); let mut res_buffer = BytePacketBuffer::new();
match packet.write(&mut res_buffer) { match packet.write(&mut res_buffer) {
Ok(_) => {} Ok(_) => {}

View File

@ -164,9 +164,7 @@ impl BytePacketBuffer {
} }
fn write_qname(&mut self, qname: &str) -> Result<()> { fn write_qname(&mut self, qname: &str) -> Result<()> {
let split_str = qname.split('.').collect::<Vec<&str>>(); for label in qname.split('.') {
for label in split_str {
let len = label.len(); let len = label.len();
if len > 0x34 { if len > 0x34 {
return Err("Single label exceeds 63 characters of length".into()); return Err("Single label exceeds 63 characters of length".into());
@ -690,84 +688,69 @@ impl DnsPacket {
Ok(()) Ok(())
} }
/// It's useful to be able to pick a random A record from a packet. When we
/// get multiple IP's for a single name, it doesn't matter which one we
/// choose, so in those cases we can now pick one at random.
pub fn get_random_a(&self) -> Option<String> { pub fn get_random_a(&self) -> Option<String> {
if !self.answers.is_empty() { self.answers
let a_record = &self.answers[0]; .iter()
if let DnsRecord::A { ref addr, .. } = *a_record { .filter_map(|record| match record {
return Some(addr.to_string()); DnsRecord::A { ref addr, .. } => Some(addr.to_string()),
} _ => None,
} })
.next()
None
} }
/// A helper function which returns an iterator over all name servers in
/// the authorities section, represented as (domain, host) tuples
fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator<Item = (&'a str, &'a str)> {
self.authorities
.iter()
// In practice, these are always NS records in well formed packages.
// Convert the NS records to a tuple which has only the data we need
// to make it easy to work with.
.filter_map(|record| match record {
DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())),
_ => None,
})
// Discard servers which aren't authoritative to our query
.filter(move |(domain, _)| qname.ends_with(*domain))
}
/// We'll use the fact that name servers often bundle the corresponding
/// A records when replying to an NS query to implement a function that
/// returns the actual IP for an NS record if possible.
pub fn get_resolved_ns(&self, qname: &str) -> Option<String> { pub fn get_resolved_ns(&self, qname: &str) -> Option<String> {
let mut new_authorities = Vec::new(); // Get an iterator over the nameservers in the authorities section
for auth in &self.authorities { self.get_ns(qname)
if let DnsRecord::NS { // Now we need to look for a matching A record in the additional
ref domain, // section. Since we just want the first valid record, we can just
ref host, // build a stream of matching records.
.. .flat_map(|(_, host)| {
} = *auth self.resources
{ .iter()
if !qname.ends_with(domain) { // Filter for A records where the domain match the host
continue; // of the NS record that we are currently processing
} .filter_map(move |record| match record {
DnsRecord::A { domain, addr, .. } if domain == host => Some(addr),
for rsrc in &self.resources { _ => None,
if let DnsRecord::A { })
ref domain, })
ref addr, .map(|addr| addr.to_string())
ttl, // Finally, pick the first valid entry
} = *rsrc .next()
{
if domain != host {
continue;
}
let rec = DnsRecord::A {
domain: host.clone(),
addr: *addr,
ttl: ttl,
};
new_authorities.push(rec);
}
}
}
}
if !new_authorities.is_empty() {
if let DnsRecord::A { addr, .. } = new_authorities[0] {
return Some(addr.to_string());
}
}
None
} }
/// However, not all name servers are as that nice. In certain cases there won't
/// be any A records in the additional section, and we'll have to perform *another*
/// lookup in the midst. For this, we introduce a method for returning the host
/// name of an appropriate name server.
pub fn get_unresolved_ns(&self, qname: &str) -> Option<String> { pub fn get_unresolved_ns(&self, qname: &str) -> Option<String> {
let mut new_authorities = Vec::new(); // Get an iterator over the nameservers in the authorities section
for auth in &self.authorities { self.get_ns(qname)
if let DnsRecord::NS { .map(|(_, host)| host.to_string())
ref domain, // Finally, pick the first valid entry
ref host, .next()
..
} = *auth
{
if !qname.ends_with(domain) {
continue;
}
new_authorities.push(host);
}
}
if !new_authorities.is_empty() {
return Some(new_authorities[0].clone());
}
None
} }
} }
@ -794,45 +777,53 @@ fn lookup(qname: &str, qtype: QueryType, server: (&str, u16)) -> Result<DnsPacke
} }
fn recursive_lookup(qname: &str, qtype: QueryType) -> Result<DnsPacket> { fn recursive_lookup(qname: &str, qtype: QueryType) -> Result<DnsPacket> {
// For now we're always starting with *a.root-servers.net*.
let mut ns = "198.41.0.4".to_string(); let mut ns = "198.41.0.4".to_string();
// Start querying name servers // Since it might take an arbitrary number of steps, we enter an unbounded loop.
loop { loop {
println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns); println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns);
// The next step is to send the query to the active server.
let ns_copy = ns.clone(); let ns_copy = ns.clone();
let server = (ns_copy.as_str(), 53); let server = (ns_copy.as_str(), 53);
let response = lookup(qname, qtype.clone(), server)?; let response = lookup(qname, qtype.clone(), server)?;
// If we've got an actual answer, we're done! // If there are entries in the answer section, and no errors, we are done!
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
return Ok(response.clone()); return Ok(response.clone());
} }
// We might also get a `NXDOMAIN` reply, which is the authoritative name servers
// way of telling us that the name doesn't exist.
if response.header.rescode == ResultCode::NXDOMAIN { if response.header.rescode == ResultCode::NXDOMAIN {
return Ok(response.clone()); return Ok(response.clone());
} }
// Otherwise, try to find a new nameserver based on NS and a // Otherwise, we'll try to find a new nameserver based on NS and a corresponding A
// corresponding A record in the additional section // record in the additional section. If this succeeds, we can switch name server
// and retry the loop.
if let Some(new_ns) = response.get_resolved_ns(qname) { if let Some(new_ns) = response.get_resolved_ns(qname) {
// If there is such a record, we can retry the loop with that NS
ns = new_ns.clone(); ns = new_ns.clone();
continue; continue;
} }
// If not, we'll have to resolve the ip of a NS record // If not, we'll have to resolve the ip of a NS record. If no NS records exist,
// we'll go with what the last server told us.
let new_ns_name = match response.get_unresolved_ns(qname) { let new_ns_name = match response.get_unresolved_ns(qname) {
Some(x) => x, Some(x) => x,
None => return Ok(response.clone()), None => return Ok(response.clone()),
}; };
// Recursively resolve the NS // Here we go down the rabbit hole by starting _another_ lookup sequence in the
// midst of our current one. Hopefully, this will give us the IP of an appropriate
// name server.
let recursive_response = recursive_lookup(&new_ns_name, QueryType::A)?; let recursive_response = recursive_lookup(&new_ns_name, QueryType::A)?;
// Pick a random IP and restart // Finally, we pick a random ip from the result, and restart the loop. If no such
// record is available, we again return the last result we got.
if let Some(new_ns) = recursive_response.get_random_a() { if let Some(new_ns) = recursive_response.get_random_a() {
ns = new_ns.clone(); ns = new_ns.clone();
} else { } else {