Eliminate all useless allocation

This commit is contained in:
Emil Hernvall 2020-06-18 02:23:42 +02:00
parent 45cbfa59f7
commit 4202cae860
7 changed files with 81 additions and 75 deletions

View File

@ -97,7 +97,7 @@ impl DnsHeader {
)?; )?;
buffer.write_u8( buffer.write_u8(
(self.rescode.clone() as u8) (self.rescode as u8)
| ((self.checking_disabled as u8) << 4) | ((self.checking_disabled as u8) << 4)
| ((self.authed_data as u8) << 5) | ((self.authed_data as u8) << 5)
| ((self.z as u8) << 6) | ((self.z as u8) << 6)

View File

@ -190,7 +190,7 @@ fn handle_query(socket: &UdpSocket) -> Result<()> {
// Next, `DnsPacket::from_buffer` is used to parse the raw bytes into // Next, `DnsPacket::from_buffer` is used to parse the raw bytes into
// a `DnsPacket`. // a `DnsPacket`.
let request = DnsPacket::from_buffer(&mut req_buffer)?; let mut request = DnsPacket::from_buffer(&mut req_buffer)?;
// Create and initialize the response packet // Create and initialize the response packet
let mut packet = DnsPacket::new(); let mut packet = DnsPacket::new();
@ -199,15 +199,8 @@ fn handle_query(socket: &UdpSocket) -> Result<()> {
packet.header.recursion_available = true; packet.header.recursion_available = true;
packet.header.response = true; packet.header.response = true;
// Being mindful of how unreliable input data from arbitrary senders can be, we // In the normal case, exactly one question is present
// need make sure that a question is actually present. If not, we return `FORMERR` if let Some(question) = request.questions.pop() {
// to indicate that the sender made something wrong.
if request.questions.is_empty() {
packet.header.rescode = ResultCode::FORMERR;
}
// Usually a question will be present, though.
else {
let question = &request.questions[0];
println!("Received query: {:?}", question); println!("Received query: {:?}", question);
// Since all is set up and as expected, the query can be forwarded to the // Since all is set up and as expected, the query can be forwarded to the
@ -216,7 +209,7 @@ fn handle_query(socket: &UdpSocket) -> Result<()> {
// as much to the client. If rather everything goes as planned, the // as much to the client. If rather everything goes as planned, the
// question and response records as copied into our response packet. // question and response records as copied into our response packet.
if let Ok(result) = lookup(&question.name, question.qtype) { if let Ok(result) = lookup(&question.name, question.qtype) {
packet.questions.push(question.clone()); packet.questions.push(question);
packet.header.rescode = result.header.rescode; packet.header.rescode = result.header.rescode;
for rec in result.answers { for rec in result.answers {
@ -235,6 +228,12 @@ fn handle_query(socket: &UdpSocket) -> Result<()> {
packet.header.rescode = ResultCode::SERVFAIL; packet.header.rescode = ResultCode::SERVFAIL;
} }
} }
// Being mindful of how unreliable input data from arbitrary senders can be, we
// need make sure that a question is actually present. If not, we return `FORMERR`
// to indicate that the sender made something wrong.
else {
packet.header.rescode = ResultCode::FORMERR;
}
// The only thing remaining is to encode our response and send it off! // The only thing remaining is to encode our response and send it off!
let mut res_buffer = BytePacketBuffer::new(); let mut res_buffer = BytePacketBuffer::new();

View File

@ -171,11 +171,11 @@ impl DnsPacket {
/// It's useful to be able to pick a random A record from a packet. When we /// It's useful to be able to pick a random A record from a packet. When we
/// get multiple IP's for a single name, it doesn't matter which one we /// get multiple IP's for a single name, it doesn't matter which one we
/// choose, so in those cases we can now pick one at random. /// choose, so in those cases we can now pick one at random.
pub fn get_random_a(&self) -> Option<String> { pub fn get_random_a(&self) -> Option<Ipv4Addr> {
self.answers self.answers
.iter() .iter()
.filter_map(|record| match record { .filter_map(|record| match record {
DnsRecord::A { ref addr, .. } => Some(addr.to_string()), DnsRecord::A { addr, .. } => Some(*addr),
_ => None, _ => None,
}) })
.next() .next()
@ -183,8 +183,9 @@ impl DnsPacket {
/// A helper function which returns an iterator over all name servers in /// A helper function which returns an iterator over all name servers in
/// the authorities section, represented as (domain, host) tuples /// the authorities section, represented as (domain, host) tuples
fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator<Item=(&'a str, &'a str)> { fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator<Item = (&'a str, &'a str)> {
self.authorities.iter() self.authorities
.iter()
// In practice, these are always NS records in well formed packages. // In practice, these are always NS records in well formed packages.
// Convert the NS records to a tuple which has only the data we need // Convert the NS records to a tuple which has only the data we need
// to make it easy to work with. // to make it easy to work with.
@ -196,17 +197,18 @@ impl DnsPacket {
.filter(move |(domain, _)| qname.ends_with(*domain)) .filter(move |(domain, _)| qname.ends_with(*domain))
} }
/// When there is a NS record in the authorities section, there may also /// We'll use the fact that name servers often bundle the corresponding
/// be a matching A record in the additional section. This saves us /// A records when replying to an NS query to implement a function that
/// from doing a separate query to resolve the IP of the name server. /// returns the actual IP for an NS record if possible.
pub fn get_resolved_ns(&self, qname: &str) -> Option<String> { pub fn get_resolved_ns(&self, qname: &str) -> Option<Ipv4Addr> {
// Get an iterator over the nameservers in the authorities section // Get an iterator over the nameservers in the authorities section
self.get_ns(qname) self.get_ns(qname)
// Now we need to look for a matching A record in the additional // Now we need to look for a matching A record in the additional
// section. Since we just want the first valid record, we can just // section. Since we just want the first valid record, we can just
// build a stream of matching records. // build a stream of matching records.
.flat_map(|(_, host)| { .flat_map(|(_, host)| {
self.resources.iter() self.resources
.iter()
// Filter for A records where the domain match the host // Filter for A records where the domain match the host
// of the NS record that we are currently processing // of the NS record that we are currently processing
.filter_map(move |record| match record { .filter_map(move |record| match record {
@ -214,19 +216,19 @@ impl DnsPacket {
_ => None, _ => None,
}) })
}) })
.map(|addr| addr.to_string()) .map(|addr| *addr)
// Finally, pick the first valid entry // Finally, pick the first valid entry
.next() .next()
} }
/// However, not all name servers are as that nice. In certain cases there won't /// However, not all name servers are as that nice. In certain cases there won't
/// be any A records in the additional section, and we'll have to perform *another* /// be any A records in the additional section, and we'll have to perform *another*
/// lookup in the midst of our first. For this, we introduce a method for /// lookup in the midst. For this, we introduce a method for returning the host
/// returning the hostname of an appropriate name server. /// name of an appropriate name server.
pub fn get_unresolved_ns(&self, qname: &str) -> Option<String> { pub fn get_unresolved_ns<'a>(&'a self, qname: &'a str) -> Option<&'a str> {
// Get an iterator over the nameservers in the authorities section // Get an iterator over the nameservers in the authorities section
self.get_ns(qname) self.get_ns(qname)
.map(|(_, host)| host.to_string()) .map(|(_, host)| host)
// Finally, pick the first valid entry // Finally, pick the first valid entry
.next() .next()
} }
@ -240,36 +242,35 @@ We move swiftly on to our new `recursive_lookup` function:
```rust ```rust
fn recursive_lookup(qname: &str, qtype: QueryType) -> Result<DnsPacket> { fn recursive_lookup(qname: &str, qtype: QueryType) -> Result<DnsPacket> {
// For now we're always starting with *a.root-servers.net*. // For now we're always starting with *a.root-servers.net*.
let mut ns = "198.41.0.4".to_string(); let mut ns = "198.41.0.4".parse::<Ipv4Addr>().unwrap();
// Since it might take an arbitrary number of steps, we enter an unbounded loop. // Since it might take an arbitrary number of steps, we enter an unbounded loop.
loop { loop {
println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns); println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns);
// The next step is to send the query to the active server. // The next step is to send the query to the active server.
let ns_copy = ns.clone(); let ns_copy = ns;
let server = (ns_copy.as_str(), 53); let server = (ns_copy, 53);
let response = lookup(qname, qtype.clone(), server)?; let response = lookup(qname, qtype, server)?;
// If there are entries in the answer section, and no errors, we are done! // If there are entries in the answer section, and no errors, we are done!
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
return Ok(response.clone()); return Ok(response);
} }
// We might also get a `NXDOMAIN` reply, which is the authoritative name servers // We might also get a `NXDOMAIN` reply, which is the authoritative name servers
// way of telling us that the name doesn't exist. // way of telling us that the name doesn't exist.
if response.header.rescode == ResultCode::NXDOMAIN { if response.header.rescode == ResultCode::NXDOMAIN {
return Ok(response.clone()); return Ok(response);
} }
// Otherwise, we'll try to find a new nameserver based on NS and a corresponding A // Otherwise, we'll try to find a new nameserver based on NS and a corresponding A
// record in the additional section. If this succeeds, we can switch name server // record in the additional section. If this succeeds, we can switch name server
// and retry the loop. // and retry the loop.
if let Some(new_ns) = response.get_resolved_ns(qname) { if let Some(new_ns) = response.get_resolved_ns(qname) {
ns = new_ns.clone(); ns = new_ns;
continue; continue;
} }
@ -278,7 +279,7 @@ fn recursive_lookup(qname: &str, qtype: QueryType) -> Result<DnsPacket> {
// we'll go with what the last server told us. // we'll go with what the last server told us.
let new_ns_name = match response.get_unresolved_ns(qname) { let new_ns_name = match response.get_unresolved_ns(qname) {
Some(x) => x, Some(x) => x,
None => return Ok(response.clone()), None => return Ok(response),
}; };
// Here we go down the rabbit hole by starting _another_ lookup sequence in the // Here we go down the rabbit hole by starting _another_ lookup sequence in the
@ -289,14 +290,22 @@ fn recursive_lookup(qname: &str, qtype: QueryType) -> Result<DnsPacket> {
// Finally, we pick a random ip from the result, and restart the loop. If no such // Finally, we pick a random ip from the result, and restart the loop. If no such
// record is available, we again return the last result we got. // record is available, we again return the last result we got.
if let Some(new_ns) = recursive_response.get_random_a() { if let Some(new_ns) = recursive_response.get_random_a() {
ns = new_ns.clone(); ns = new_ns;
} else { } else {
return Ok(response.clone()); return Ok(response);
} }
} }
} }
``` ```
This also requires a small change to the `lookup` function, as we need to
pass which server to use. We add a server parameter to the function signature,
and remove the hardcoded `server` variable we used in chapter 4:
```rust
fn lookup(qname: &str, qtype: QueryType, server: (Ipv4Addr, u16)) -> Result<DnsPacket> {
```
### Trying out recursive lookup ### Trying out recursive lookup
The only thing remaining is to change our `handle_query` function to use The only thing remaining is to change our `handle_query` function to use

View File

@ -290,7 +290,7 @@ impl DnsHeader {
)?; )?;
buffer.write_u8( buffer.write_u8(
(self.rescode.clone() as u8) (self.rescode as u8)
| ((self.checking_disabled as u8) << 4) | ((self.checking_disabled as u8) << 4)
| ((self.authed_data as u8) << 5) | ((self.authed_data as u8) << 5)
| ((self.z as u8) << 6) | ((self.z as u8) << 6)

View File

@ -303,7 +303,7 @@ impl DnsHeader {
)?; )?;
buffer.write_u8( buffer.write_u8(
(self.rescode.clone() as u8) (self.rescode as u8)
| ((self.checking_disabled as u8) << 4) | ((self.checking_disabled as u8) << 4)
| ((self.authed_data as u8) << 5) | ((self.authed_data as u8) << 5)
| ((self.z as u8) << 6) | ((self.z as u8) << 6)

View File

@ -303,7 +303,7 @@ impl DnsHeader {
)?; )?;
buffer.write_u8( buffer.write_u8(
(self.rescode.clone() as u8) (self.rescode as u8)
| ((self.checking_disabled as u8) << 4) | ((self.checking_disabled as u8) << 4)
| ((self.authed_data as u8) << 5) | ((self.authed_data as u8) << 5)
| ((self.z as u8) << 6) | ((self.z as u8) << 6)
@ -643,7 +643,7 @@ impl DnsPacket {
result.header.read(buffer)?; result.header.read(buffer)?;
for _ in 0..result.header.questions { for _ in 0..result.header.questions {
let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0)); let mut question = DnsQuestion::new(String::new(), QueryType::UNKNOWN(0));
question.read(buffer)?; question.read(buffer)?;
result.questions.push(question); result.questions.push(question);
} }
@ -728,7 +728,7 @@ fn handle_query(socket: &UdpSocket) -> Result<()> {
// Next, `DnsPacket::from_buffer` is used to parse the raw bytes into // Next, `DnsPacket::from_buffer` is used to parse the raw bytes into
// a `DnsPacket`. // a `DnsPacket`.
let request = DnsPacket::from_buffer(&mut req_buffer)?; let mut request = DnsPacket::from_buffer(&mut req_buffer)?;
// Create and initialize the response packet // Create and initialize the response packet
let mut packet = DnsPacket::new(); let mut packet = DnsPacket::new();
@ -737,15 +737,8 @@ fn handle_query(socket: &UdpSocket) -> Result<()> {
packet.header.recursion_available = true; packet.header.recursion_available = true;
packet.header.response = true; packet.header.response = true;
// Being mindful of how unreliable input data from arbitrary senders can be, we // In the normal case, exactly one question is present
// need make sure that a question is actually present. If not, we return `FORMERR` if let Some(question) = request.questions.pop() {
// to indicate that the sender made something wrong.
if request.questions.is_empty() {
packet.header.rescode = ResultCode::FORMERR;
}
// Usually a question will be present, though.
else {
let question = &request.questions[0];
println!("Received query: {:?}", question); println!("Received query: {:?}", question);
// Since all is set up and as expected, the query can be forwarded to the // Since all is set up and as expected, the query can be forwarded to the
@ -754,7 +747,7 @@ fn handle_query(socket: &UdpSocket) -> Result<()> {
// as much to the client. If rather everything goes as planned, the // as much to the client. If rather everything goes as planned, the
// question and response records as copied into our response packet. // question and response records as copied into our response packet.
if let Ok(result) = lookup(&question.name, question.qtype) { if let Ok(result) = lookup(&question.name, question.qtype) {
packet.questions.push(question.clone()); packet.questions.push(question);
packet.header.rescode = result.header.rescode; packet.header.rescode = result.header.rescode;
for rec in result.answers { for rec in result.answers {
@ -773,6 +766,12 @@ fn handle_query(socket: &UdpSocket) -> Result<()> {
packet.header.rescode = ResultCode::SERVFAIL; packet.header.rescode = ResultCode::SERVFAIL;
} }
} }
// Being mindful of how unreliable input data from arbitrary senders can be, we
// need make sure that a question is actually present. If not, we return `FORMERR`
// to indicate that the sender made something wrong.
else {
packet.header.rescode = ResultCode::FORMERR;
}
// The only thing remaining is to encode our response and send it off! // The only thing remaining is to encode our response and send it off!
let mut res_buffer = BytePacketBuffer::new(); let mut res_buffer = BytePacketBuffer::new();

View File

@ -303,7 +303,7 @@ impl DnsHeader {
)?; )?;
buffer.write_u8( buffer.write_u8(
(self.rescode.clone() as u8) (self.rescode as u8)
| ((self.checking_disabled as u8) << 4) | ((self.checking_disabled as u8) << 4)
| ((self.authed_data as u8) << 5) | ((self.authed_data as u8) << 5)
| ((self.z as u8) << 6) | ((self.z as u8) << 6)
@ -691,11 +691,11 @@ impl DnsPacket {
/// It's useful to be able to pick a random A record from a packet. When we /// It's useful to be able to pick a random A record from a packet. When we
/// get multiple IP's for a single name, it doesn't matter which one we /// get multiple IP's for a single name, it doesn't matter which one we
/// choose, so in those cases we can now pick one at random. /// choose, so in those cases we can now pick one at random.
pub fn get_random_a(&self) -> Option<String> { pub fn get_random_a(&self) -> Option<Ipv4Addr> {
self.answers self.answers
.iter() .iter()
.filter_map(|record| match record { .filter_map(|record| match record {
DnsRecord::A { ref addr, .. } => Some(addr.to_string()), DnsRecord::A { addr, .. } => Some(*addr),
_ => None, _ => None,
}) })
.next() .next()
@ -720,7 +720,7 @@ impl DnsPacket {
/// We'll use the fact that name servers often bundle the corresponding /// We'll use the fact that name servers often bundle the corresponding
/// A records when replying to an NS query to implement a function that /// A records when replying to an NS query to implement a function that
/// returns the actual IP for an NS record if possible. /// returns the actual IP for an NS record if possible.
pub fn get_resolved_ns(&self, qname: &str) -> Option<String> { pub fn get_resolved_ns(&self, qname: &str) -> Option<Ipv4Addr> {
// Get an iterator over the nameservers in the authorities section // Get an iterator over the nameservers in the authorities section
self.get_ns(qname) self.get_ns(qname)
// Now we need to look for a matching A record in the additional // Now we need to look for a matching A record in the additional
@ -736,7 +736,7 @@ impl DnsPacket {
_ => None, _ => None,
}) })
}) })
.map(|addr| addr.to_string()) .map(|addr| *addr)
// Finally, pick the first valid entry // Finally, pick the first valid entry
.next() .next()
} }
@ -745,16 +745,16 @@ impl DnsPacket {
/// be any A records in the additional section, and we'll have to perform *another* /// be any A records in the additional section, and we'll have to perform *another*
/// lookup in the midst. For this, we introduce a method for returning the host /// lookup in the midst. For this, we introduce a method for returning the host
/// name of an appropriate name server. /// name of an appropriate name server.
pub fn get_unresolved_ns(&self, qname: &str) -> Option<String> { pub fn get_unresolved_ns<'a>(&'a self, qname: &'a str) -> Option<&'a str> {
// Get an iterator over the nameservers in the authorities section // Get an iterator over the nameservers in the authorities section
self.get_ns(qname) self.get_ns(qname)
.map(|(_, host)| host.to_string()) .map(|(_, host)| host)
// Finally, pick the first valid entry // Finally, pick the first valid entry
.next() .next()
} }
} }
fn lookup(qname: &str, qtype: QueryType, server: (&str, u16)) -> Result<DnsPacket> { fn lookup(qname: &str, qtype: QueryType, server: (Ipv4Addr, u16)) -> Result<DnsPacket> {
let socket = UdpSocket::bind(("0.0.0.0", 43210))?; let socket = UdpSocket::bind(("0.0.0.0", 43210))?;
let mut packet = DnsPacket::new(); let mut packet = DnsPacket::new();
@ -778,34 +778,34 @@ fn lookup(qname: &str, qtype: QueryType, server: (&str, u16)) -> Result<DnsPacke
fn recursive_lookup(qname: &str, qtype: QueryType) -> Result<DnsPacket> { fn recursive_lookup(qname: &str, qtype: QueryType) -> Result<DnsPacket> {
// For now we're always starting with *a.root-servers.net*. // For now we're always starting with *a.root-servers.net*.
let mut ns = "198.41.0.4".to_string(); let mut ns = "198.41.0.4".parse::<Ipv4Addr>().unwrap();
// Since it might take an arbitrary number of steps, we enter an unbounded loop. // Since it might take an arbitrary number of steps, we enter an unbounded loop.
loop { loop {
println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns); println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns);
// The next step is to send the query to the active server. // The next step is to send the query to the active server.
let ns_copy = ns.clone(); let ns_copy = ns;
let server = (ns_copy.as_str(), 53); let server = (ns_copy, 53);
let response = lookup(qname, qtype.clone(), server)?; let response = lookup(qname, qtype, server)?;
// If there are entries in the answer section, and no errors, we are done! // If there are entries in the answer section, and no errors, we are done!
if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR {
return Ok(response.clone()); return Ok(response);
} }
// We might also get a `NXDOMAIN` reply, which is the authoritative name servers // We might also get a `NXDOMAIN` reply, which is the authoritative name servers
// way of telling us that the name doesn't exist. // way of telling us that the name doesn't exist.
if response.header.rescode == ResultCode::NXDOMAIN { if response.header.rescode == ResultCode::NXDOMAIN {
return Ok(response.clone()); return Ok(response);
} }
// Otherwise, we'll try to find a new nameserver based on NS and a corresponding A // Otherwise, we'll try to find a new nameserver based on NS and a corresponding A
// record in the additional section. If this succeeds, we can switch name server // record in the additional section. If this succeeds, we can switch name server
// and retry the loop. // and retry the loop.
if let Some(new_ns) = response.get_resolved_ns(qname) { if let Some(new_ns) = response.get_resolved_ns(qname) {
ns = new_ns.clone(); ns = new_ns;
continue; continue;
} }
@ -814,7 +814,7 @@ fn recursive_lookup(qname: &str, qtype: QueryType) -> Result<DnsPacket> {
// we'll go with what the last server told us. // we'll go with what the last server told us.
let new_ns_name = match response.get_unresolved_ns(qname) { let new_ns_name = match response.get_unresolved_ns(qname) {
Some(x) => x, Some(x) => x,
None => return Ok(response.clone()), None => return Ok(response),
}; };
// Here we go down the rabbit hole by starting _another_ lookup sequence in the // Here we go down the rabbit hole by starting _another_ lookup sequence in the
@ -825,9 +825,9 @@ fn recursive_lookup(qname: &str, qtype: QueryType) -> Result<DnsPacket> {
// Finally, we pick a random ip from the result, and restart the loop. If no such // Finally, we pick a random ip from the result, and restart the loop. If no such
// record is available, we again return the last result we got. // record is available, we again return the last result we got.
if let Some(new_ns) = recursive_response.get_random_a() { if let Some(new_ns) = recursive_response.get_random_a() {
ns = new_ns.clone(); ns = new_ns;
} else { } else {
return Ok(response.clone()); return Ok(response);
} }
} }
} }
@ -836,7 +836,7 @@ fn handle_query(socket: &UdpSocket) -> Result<()> {
let mut req_buffer = BytePacketBuffer::new(); let mut req_buffer = BytePacketBuffer::new();
let (_, src) = socket.recv_from(&mut req_buffer.buf)?; let (_, src) = socket.recv_from(&mut req_buffer.buf)?;
let request = DnsPacket::from_buffer(&mut req_buffer)?; let mut request = DnsPacket::from_buffer(&mut req_buffer)?;
let mut packet = DnsPacket::new(); let mut packet = DnsPacket::new();
packet.header.id = request.header.id; packet.header.id = request.header.id;
@ -844,10 +844,7 @@ fn handle_query(socket: &UdpSocket) -> Result<()> {
packet.header.recursion_available = true; packet.header.recursion_available = true;
packet.header.response = true; packet.header.response = true;
if request.questions.is_empty() { if let Some(question) = request.questions.pop() {
packet.header.rescode = ResultCode::FORMERR;
} else {
let question = &request.questions[0];
println!("Received query: {:?}", question); println!("Received query: {:?}", question);
if let Ok(result) = recursive_lookup(&question.name, question.qtype) { if let Ok(result) = recursive_lookup(&question.name, question.qtype) {
@ -869,6 +866,8 @@ fn handle_query(socket: &UdpSocket) -> Result<()> {
} else { } else {
packet.header.rescode = ResultCode::SERVFAIL; packet.header.rescode = ResultCode::SERVFAIL;
} }
} else {
packet.header.rescode = ResultCode::FORMERR;
} }
let mut res_buffer = BytePacketBuffer::new(); let mut res_buffer = BytePacketBuffer::new();