From 9648a9e57649fec825fef36e5070e55cade8591e Mon Sep 17 00:00:00 2001 From: nabil salah Date: Mon, 21 Oct 2024 16:25:26 +0300 Subject: [PATCH 01/12] feat: implement dnsserver Signed-off-by: nabil salah --- .gitignore | 4 + Cargo.lock | 7 + Cargo.toml | 6 + src/main.rs | 844 ++++++++++++++++++++++++++++++++++++++++++++++++++++ 4 files changed, 861 insertions(+) create mode 100644 .gitignore create mode 100644 Cargo.lock create mode 100644 Cargo.toml create mode 100644 src/main.rs diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..41a1821 --- /dev/null +++ b/.gitignore @@ -0,0 +1,4 @@ +/target +query_packet.txt +response_packet.txt +.vscode \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..1179058 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,7 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "dnsserver-nabil" +version = "0.1.0" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..ea6f41a --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,6 @@ +[package] +name = "dnsserver-nabil" +version = "0.1.0" +edition = "2021" + +[dependencies] diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..753d4b6 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,844 @@ +use std::{fs::File, io::{self, Error, Read}, net::{Ipv4Addr, Ipv6Addr, UdpSocket}}; + +pub struct BytePacketBuffer { + pub buf: [u8; 512], + pub pos: usize, +} + +impl BytePacketBuffer { + pub fn new () -> BytePacketBuffer { + return BytePacketBuffer { + buf: [0; 512], + pos: 0 + }; + } + fn pos(&self) -> usize{ + return self.pos; + } + fn step(&mut self, steps: usize) -> Result<(), io::Error> { + self.pos += steps; + Ok(()) + } + fn seek(&mut self, pos: usize) -> Result<(), io::Error> { + self.pos = pos; + + Ok(()) + } + fn read(&mut self) -> Result { + if self.pos >= 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + let res = self.buf[self.pos]; + self.pos += 1; + Ok(res) + } + fn get(& self, pos: usize) -> Result { + if pos >= 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + Ok(self.buf[pos]) + } + fn get_range(& self, start: usize, len: usize) -> Result<&[u8], io::Error> { + if start+len >= 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + Ok(&self.buf[start..start+len as usize]) + } + fn read_u16(&mut self) -> Result { + let res = ( ( self.read()? as u16) << 8) | ( self.read()? as u16); + Ok(res) + } + fn read_u32(&mut self) -> Result { + let res = ((self.read()? as u32) << 24) + | ((self.read()? as u32) << 16) + | ((self.read()? as u32) << 8) + | ((self.read()? as u32) << 0); + + Ok(res) + } + fn read_qname(&mut self, outstr: &mut String) -> Result<(), io::Error> { + + let mut pos = self.pos(); + let mut jumped = false; + let max_jumps = 5; + let mut jumps_performed = 0; + let mut delim = ""; + loop { + if jumps_performed > max_jumps { + return Err(Error::new(io::ErrorKind::UnexpectedEof, format!("Limit of {} jumps exceeded", max_jumps))); + } + + let len = self.get(pos)?; + + if (len & 0xC0) == 0xC0 { + if !jumped { + self.seek(pos+2)? + } + + let b2 = self.get(pos + 1)? as u16; + let offset = (((len as u16) ^ 0xC0) << 8) | b2; + pos = offset as usize; + + jumped = true; + jumps_performed += 1; + }else { + pos += 1; + + if len == 0 { + break; + } + outstr.push_str(delim); + let str_buffer = self.get_range(pos, len as usize)?; + outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase()); + + delim = "."; + pos += len as usize; + } + } + if !jumped { + self.seek(pos)?; + } + + Ok(()) + } + fn write(&mut self, val: u8) -> Result<(), io::Error> { + if self.pos >= 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + self.buf[self.pos] = val; + self.pos += 1; + Ok(()) + } + + fn write_u8(&mut self, val: u8) -> Result<(), io::Error> { + self.write(val)?; + + Ok(()) + } + + fn write_u16(&mut self, val: u16) -> Result<(), io::Error> { + self.write((val >> 8) as u8)?; + self.write((val & 0xFF) as u8)?; + + Ok(()) + } + + fn write_u32(&mut self, val: u32) -> Result<(), io::Error> { + self.write(((val >> 24) & 0xFF) as u8)?; + self.write(((val >> 16) & 0xFF) as u8)?; + self.write(((val >> 8) & 0xFF) as u8)?; + self.write(((val >> 0) & 0xFF) as u8)?; + + Ok(()) + } + fn write_qname(&mut self, qname: &str) -> Result<(), io::Error> { + for label in qname.split('.') { + let len = label.len(); + if len > 0x3f { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "Single label exceeds 63 characters of length")); + } + + self.write_u8(len as u8)?; + for b in label.as_bytes() { + self.write_u8(*b)?; + } + } + + self.write_u8(0)?; + + Ok(()) + } + + fn set(&mut self, pos: usize, val: u8) -> Result<(), io::Error> { + self.buf[pos] = val; + + Ok(()) + } + + fn set_u16(&mut self, pos: usize, val: u16) -> Result<(), io::Error> { + self.set(pos, (val >> 8) as u8)?; + self.set(pos + 1, (val & 0xFF) as u8)?; + + Ok(()) + } +} + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum ResultCode { + NOERROR = 0, + FORMERR = 1, + SERVFAIL = 2, + NXDOMAIN = 3, + NOTIMP = 4, + REFUSED = 5, +} + +impl ResultCode { + pub fn from_num(num: u8) -> ResultCode { + match num { + 1 => ResultCode::FORMERR, + 2 => ResultCode::SERVFAIL, + 3 => ResultCode::NXDOMAIN, + 4 => ResultCode::NOTIMP, + 5 => ResultCode::REFUSED, + 0 | _ => ResultCode::NOERROR, + } + } +} + + +#[derive(Clone, Debug)] +pub struct DnsHeader { + pub id: u16, // 16 bits + + pub recursion_desired: bool, // 1 bit + pub truncated_message: bool, // 1 bit + pub authoritative_answer: bool, // 1 bit + pub opcode: u8, // 4 bits + pub response: bool, // 1 bit + + pub rescode: ResultCode, // 4 bits + pub checking_disabled: bool, // 1 bit + pub authed_data: bool, // 1 bit + pub z: bool, // 1 bit + pub recursion_available: bool, // 1 bit + + pub questions: u16, // 16 bits + pub answers: u16, // 16 bits + pub authoritative_entries: u16, // 16 bits + pub resource_entries: u16, // 16 bits +} + +impl DnsHeader { + pub fn new() -> DnsHeader { + DnsHeader { + id: 0, + + recursion_desired: false, + truncated_message: false, + authoritative_answer: false, + opcode: 0, + response: false, + + rescode: ResultCode::NOERROR, + checking_disabled: false, + authed_data: false, + z: false, + recursion_available: false, + + questions: 0, + answers: 0, + authoritative_entries: 0, + resource_entries: 0, + } + } + + pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { + self.id = buffer.read_u16()?; + + let mut flags = buffer.read()?; + self.recursion_desired = (flags & (1 << 0)) > 0; + self.truncated_message = (flags & (1 << 1)) > 0; + self.authoritative_answer = (flags & (1 << 2)) > 0; + self.opcode = (flags >> 3) & 0x0F; + self.response = (flags & (1 << 7)) > 0; + + flags = buffer.read()?; + self.rescode = ResultCode::from_num(flags & 0x0F); + self.checking_disabled = (flags & (1 << 4)) > 0; + self.authed_data = (flags & (1 << 5)) > 0; + self.z = (flags & (1 << 6)) > 0; + self.recursion_available = (flags & (1 << 7)) > 0; + + self.questions = buffer.read_u16()?; + self.answers = buffer.read_u16()?; + self.authoritative_entries = buffer.read_u16()?; + self.resource_entries = buffer.read_u16()?; + + Ok(()) + } + pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { + buffer.write_u16(self.id)?; + + buffer.write_u8( + (self.recursion_desired as u8) + | ((self.truncated_message as u8) << 1) + | ((self.authoritative_answer as u8) << 2) + | (self.opcode << 3) + | ((self.response as u8) << 7) as u8, + )?; + + buffer.write_u8( + (self.rescode as u8) + | ((self.checking_disabled as u8) << 4) + | ((self.authed_data as u8) << 5) + | ((self.z as u8) << 6) + | ((self.recursion_available as u8) << 7), + )?; + + buffer.write_u16(self.questions)?; + buffer.write_u16(self.answers)?; + buffer.write_u16(self.authoritative_entries)?; + buffer.write_u16(self.resource_entries)?; + + Ok(()) + } +} + +#[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)] +pub enum QueryType { + UNKNOWN(u16), + A, // 1 + NS, // 2 + CNAME, // 5 + MX, // 15 + AAAA, // 28 +} + +impl QueryType { + pub fn to_num(&self) -> u16 { + match *self { + QueryType::UNKNOWN(x) => x, + QueryType::A => 1, + QueryType::NS => 2, + QueryType::CNAME => 5, + QueryType::MX => 15, + QueryType::AAAA => 28, + } + } + + pub fn from_num(num: u16) -> QueryType { + match num { + 1 => QueryType::A, + 2 => QueryType::NS, + 5 => QueryType::CNAME, + 15 => QueryType::MX, + 28 => QueryType::AAAA, + _ => QueryType::UNKNOWN(num), + } + } +} + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DnsQuestion { + pub name: String, + pub qtype: QueryType, +} + +impl DnsQuestion { + pub fn new(name: String, qtype: QueryType) -> DnsQuestion { + DnsQuestion { + name: name, + qtype: qtype, + } + } + + pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { + buffer.read_qname(&mut self.name)?; + self.qtype = QueryType::from_num(buffer.read_u16()?); // qtype + let _ = buffer.read_u16()?; // class + + Ok(()) + } + + pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { + buffer.write_qname(&self.name)?; + + let typenum = self.qtype.to_num(); + buffer.write_u16(typenum)?; + buffer.write_u16(1)?; + + Ok(()) + } +} + +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[allow(dead_code)] +pub enum DnsRecord { + UNKNOWN { + domain: String, + qtype: u16, + data_len: u16, + ttl: u32, + }, // 0 + A { + domain: String, + addr: Ipv4Addr, + ttl: u32, + }, // 1 + NS { + domain: String, + host: String, + ttl: u32, + }, // 2 + CNAME { + domain: String, + host: String, + ttl: u32, + }, // 5 + MX { + domain: String, + priority: u16, + host: String, + ttl: u32, + }, // 15 + AAAA { + domain: String, + addr: Ipv6Addr, + ttl: u32, + }, // 28 +} + +impl DnsRecord { + pub fn read(buffer: &mut BytePacketBuffer) -> Result { + let mut domain = String::new(); + buffer.read_qname(&mut domain)?; + + let qtype_num = buffer.read_u16()?; + let qtype = QueryType::from_num(qtype_num); + let _ = buffer.read_u16()?; + let ttl = buffer.read_u32()?; + let data_len = buffer.read_u16()?; + + match qtype { + QueryType::A => { + let raw_addr = buffer.read_u32()?; + let addr = Ipv4Addr::new( + ((raw_addr >> 24) & 0xFF) as u8, + ((raw_addr >> 16) & 0xFF) as u8, + ((raw_addr >> 8) & 0xFF) as u8, + ((raw_addr >> 0) & 0xFF) as u8, + ); + + Ok(DnsRecord::A { + domain: domain, + addr: addr, + ttl: ttl, + }) + } + QueryType::AAAA => { + let raw_addr1 = buffer.read_u32()?; + let raw_addr2 = buffer.read_u32()?; + let raw_addr3 = buffer.read_u32()?; + let raw_addr4 = buffer.read_u32()?; + let addr = Ipv6Addr::new( + ((raw_addr1 >> 16) & 0xFFFF) as u16, + ((raw_addr1 >> 0) & 0xFFFF) as u16, + ((raw_addr2 >> 16) & 0xFFFF) as u16, + ((raw_addr2 >> 0) & 0xFFFF) as u16, + ((raw_addr3 >> 16) & 0xFFFF) as u16, + ((raw_addr3 >> 0) & 0xFFFF) as u16, + ((raw_addr4 >> 16) & 0xFFFF) as u16, + ((raw_addr4 >> 0) & 0xFFFF) as u16, + ); + + Ok(DnsRecord::AAAA { + domain: domain, + addr: addr, + ttl: ttl, + }) + } + QueryType::NS => { + let mut ns = String::new(); + buffer.read_qname(&mut ns)?; + + Ok(DnsRecord::NS { + domain: domain, + host: ns, + ttl: ttl, + }) + } + QueryType::CNAME => { + let mut cname = String::new(); + buffer.read_qname(&mut cname)?; + + Ok(DnsRecord::CNAME { + domain: domain, + host: cname, + ttl: ttl, + }) + } + QueryType::MX => { + let priority = buffer.read_u16()?; + let mut mx = String::new(); + buffer.read_qname(&mut mx)?; + + Ok(DnsRecord::MX { + domain: domain, + priority: priority, + host: mx, + ttl: ttl, + }) + } + QueryType::UNKNOWN(_) => { + buffer.step(data_len as usize)?; + + Ok(DnsRecord::UNKNOWN { + domain: domain, + qtype: qtype_num, + data_len: data_len, + ttl: ttl, + }) + } + } + } + + pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result { + let start_pos = buffer.pos(); + + match *self { + DnsRecord::A { + ref domain, + ref addr, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::A.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + buffer.write_u16(4)?; + + let octets = addr.octets(); + buffer.write_u8(octets[0])?; + buffer.write_u8(octets[1])?; + buffer.write_u8(octets[2])?; + buffer.write_u8(octets[3])?; + } + DnsRecord::NS { + ref domain, + ref host, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::NS.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::CNAME { + ref domain, + ref host, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::CNAME.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::MX { + ref domain, + priority, + ref host, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::MX.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_u16(priority)?; + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::AAAA { + ref domain, + ref addr, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::AAAA.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + buffer.write_u16(16)?; + + for octet in &addr.segments() { + buffer.write_u16(*octet)?; + } + } + DnsRecord::UNKNOWN { .. } => { + println!("Skipping record: {:?}", self); + } + } + + Ok(buffer.pos() - start_pos) + } +} + + +#[derive(Clone, Debug)] +pub struct DnsPacket { + pub header: DnsHeader, + pub questions: Vec, + pub answers: Vec, + pub authorities: Vec, + pub resources: Vec, +} + +impl DnsPacket { + pub fn new() -> DnsPacket { + DnsPacket { + header: DnsHeader::new(), + questions: Vec::new(), + answers: Vec::new(), + authorities: Vec::new(), + resources: Vec::new(), + } + } + + pub fn from_buffer(buffer: &mut BytePacketBuffer) -> Result { + let mut result = DnsPacket::new(); + result.header.read(buffer)?; + + for _ in 0..result.header.questions { + let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0)); + question.read(buffer)?; + result.questions.push(question); + } + + for _ in 0..result.header.answers { + let rec = DnsRecord::read(buffer)?; + result.answers.push(rec); + } + for _ in 0..result.header.authoritative_entries { + let rec = DnsRecord::read(buffer)?; + result.authorities.push(rec); + } + for _ in 0..result.header.resource_entries { + let rec = DnsRecord::read(buffer)?; + result.resources.push(rec); + } + + Ok(result) + } + + pub fn write(&mut self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { + self.header.questions = self.questions.len() as u16; + self.header.answers = self.answers.len() as u16; + self.header.authoritative_entries = self.authorities.len() as u16; + self.header.resource_entries = self.resources.len() as u16; + + self.header.write(buffer)?; + + for question in &self.questions { + question.write(buffer)?; + } + for rec in &self.answers { + rec.write(buffer)?; + } + for rec in &self.authorities { + rec.write(buffer)?; + } + for rec in &self.resources { + rec.write(buffer)?; + } + + Ok(()) + } + + pub fn get_random_a(&self) -> Option { + self.answers + .iter() + .filter_map(|record| match record { + DnsRecord::A { addr, .. } => Some(*addr), + _ => None, + }) + .next() + } + + fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator { + self.authorities + .iter() + .filter_map(|record| match record { + DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())), + _ => None, + }) + .filter(move |(domain, _)| qname.ends_with(*domain)) + } + + + pub fn get_resolved_ns(&self, qname: &str) -> Option { + self.get_ns(qname) + .flat_map(|(_, host)| { + self.resources + .iter() + .filter_map(move |record| match record { + DnsRecord::A { domain, addr, .. } if domain == host => Some(addr), + _ => None, + }) + }) + .map(|addr| *addr) + .next() + } + + pub fn get_unresolved_ns<'a>(&'a self, qname: &'a str) -> Option<&'a str> { + self.get_ns(qname) + .map(|(_, host)| host) + .next() + } +} + + +fn recursive_lookup(qname: &str, qtype: QueryType) -> Result { + // For now we're always starting with *a.root-servers.net*. + let mut ns = "198.41.0.4".parse::().unwrap(); + + loop { + println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns); + + // send the query to the active server. + let ns_copy = ns; + + let server = (ns_copy, 53); + let response = lookup(qname, qtype, server)?; + + // If there are entries in the answer section, and no errors, we are done! + if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { + return Ok(response); + } + + // We might also get a `NXDOMAIN` reply, which is the authoritative name servers + // way of telling us that the name doesn't exist. + if response.header.rescode == ResultCode::NXDOMAIN { + return Ok(response); + } + + // Otherwise, we'll try to find a new nameserver based on NS and a corresponding A + // record in the additional section. If this succeeds, we can switch name server + // and retry the loop. + if let Some(new_ns) = response.get_resolved_ns(qname) { + ns = new_ns; + + continue; + } + + // If not, we'll have to resolve the ip of a NS record. If no NS records exist, + // we'll go with what the last server told us. + let new_ns_name = match response.get_unresolved_ns(qname) { + Some(x) => x, + None => return Ok(response), + }; + + // Here we go down the rabbit hole by starting _another_ lookup sequence in the + // midst of our current one. Hopefully, this will give us the IP of an appropriate + // name server. + let recursive_response = recursive_lookup(&new_ns_name, QueryType::A)?; + + // Finally, we pick a random ip from the result, and restart the loop. If no such + // record is available, we again return the last result we got. + if let Some(new_ns) = recursive_response.get_random_a() { + ns = new_ns; + } else { + return Ok(response); + } + } +} + +fn lookup(qname: &str, qtype: QueryType, server: (Ipv4Addr, u16)) -> Result { + // Forward queries to Google's public DNS + //let server = ("8.8.8.8", 53); + + let socket = UdpSocket::bind(("0.0.0.0", 43210))?; + + let mut packet = DnsPacket::new(); + + packet.header.id = 6666; + packet.header.questions = 1; + packet.header.recursion_desired = true; + packet + .questions + .push(DnsQuestion::new(qname.to_string(), qtype)); + + let mut req_buffer = BytePacketBuffer::new(); + packet.write(&mut req_buffer)?; + socket.send_to(&req_buffer.buf[0..req_buffer.pos], server)?; + + let mut res_buffer = BytePacketBuffer::new(); + socket.recv_from(&mut res_buffer.buf)?; + + DnsPacket::from_buffer(&mut res_buffer) +} + +fn handle_query(socket: &UdpSocket) -> Result<(), io::Error> { + let mut req_buffer = BytePacketBuffer::new(); + + + let (_, src) = socket.recv_from(&mut req_buffer.buf)?; + + + let mut request = DnsPacket::from_buffer(&mut req_buffer)?; + + let mut packet = DnsPacket::new(); + packet.header.id = request.header.id; + packet.header.recursion_desired = true; + packet.header.recursion_available = true; + packet.header.response = true; + + if let Some(question) = request.questions.pop() { + println!("Received query: {:?}", question); + if let Ok(result) = recursive_lookup(&question.name, question.qtype) { + packet.questions.push(question.clone()); + packet.header.rescode = result.header.rescode; + + for rec in result.answers { + println!("Answer: {:?}", rec); + packet.answers.push(rec); + } + for rec in result.authorities { + println!("Authority: {:?}", rec); + packet.authorities.push(rec); + } + for rec in result.resources { + println!("Resource: {:?}", rec); + packet.resources.push(rec); + } + } else { + packet.header.rescode = ResultCode::SERVFAIL; + } + } else { + packet.header.rescode = ResultCode::FORMERR; + } + + let mut res_buffer = BytePacketBuffer::new(); + packet.write(&mut res_buffer)?; + + let len = res_buffer.pos(); + let data = res_buffer.get_range(0, len)?; + + socket.send_to(data, src)?; + + Ok(()) +} + +fn main() -> Result<(), io::Error> { + // Bind an UDP socket on port 2053 + let socket = UdpSocket::bind(("0.0.0.0", 2053))?; + + // For now, queries are handled sequentially, so an infinite loop for servicing + // requests is initiated. + loop { + match handle_query(&socket) { + Ok(_) => {}, + Err(e) => eprintln!("An error occurred: {}", e), + } + } +} From 73854f0786590b8e169b6737bdb32705507f1731 Mon Sep 17 00:00:00 2001 From: nabil salah Date: Wed, 23 Oct 2024 23:16:51 +0300 Subject: [PATCH 02/12] feat: add tests and documenting Signed-off-by: nabil salah --- README.md | 76 +++- makefile | 8 + src/byte_packet_buffer.rs | 374 +++++++++++++++++ src/dns_header.rs | 254 ++++++++++++ src/dns_packet.rs | 268 ++++++++++++ src/dns_question.rs | 91 +++++ src/dns_record.rs | 676 +++++++++++++++++++++++++++++++ src/lib.rs | 127 ++++++ src/main.rs | 831 +------------------------------------- src/query_types.rs | 33 ++ 10 files changed, 1908 insertions(+), 830 deletions(-) create mode 100644 makefile create mode 100644 src/byte_packet_buffer.rs create mode 100644 src/dns_header.rs create mode 100644 src/dns_packet.rs create mode 100644 src/dns_question.rs create mode 100644 src/dns_record.rs create mode 100644 src/lib.rs create mode 100644 src/query_types.rs diff --git a/README.md b/README.md index e3d77c5..e770c7b 100644 --- a/README.md +++ b/README.md @@ -1 +1,75 @@ -# dnsserver-nabil \ No newline at end of file +# dnsserver-nabil + +A simple and efficient DNS server implemented in Rust, designed to handle DNS queries and responses. This project includes documentation and tests to ensure functionality and reliability. + +## Features + +- **DNS Query Handling**: Supports various DNS query types, including A, AAAA, MX, NS, and CNAME. +- **Custom Implementation**: Fully implemented from scratch with a focus on clarity and maintainability. +- **Documentation**: Inline documentation explaining the code structure and logic. +- **Tests**: Comprehensive test suite covering all aspects of the DNS server functionality. + +## Getting Started + +### Prerequisites + +- Rust and Cargo installed on your machine. You can install Rust using [rustup](https://rustup.rs/). + +### Installation + +1. Clone the repository: + + ```bash + git clone https://github.com/codescalersinternships/dnsserver-nabil/tree/development + ``` + +2. Navigate to the project directory: + + ```bash + cd dnsserver-nabil + ``` + +3. Build the project: + + ```bash + cargo build + ``` + +### Running the Server + +To run the DNS server, use the following command: + +```bash +make run +``` + +### Running Tests + +To ensure everything is working correctly, you can run the test suite with: + +```bash +make test +``` + + + +## Usage + +Once the server is running, you can send DNS queries using any DNS client or command-line tools like `dig` or `nslookup`. For example: + +```bash +dig @localhost -p 2053 example.com +``` + +## Documentation + +This project is thoroughly documented. You can find detailed explanations of the code and its functionality within the source files. Additionally, you can generate and view the documentation using: + +```bash +make doc +``` + +## Contributing + +Contributions are welcome! If you have suggestions for improvements or new features, feel free to open an issue or submit a pull request. + diff --git a/makefile b/makefile new file mode 100644 index 0000000..7109e55 --- /dev/null +++ b/makefile @@ -0,0 +1,8 @@ +run: + cargo run + +test: + cargo test + +doc: + cargo doc --open \ No newline at end of file diff --git a/src/byte_packet_buffer.rs b/src/byte_packet_buffer.rs new file mode 100644 index 0000000..f3bcaeb --- /dev/null +++ b/src/byte_packet_buffer.rs @@ -0,0 +1,374 @@ +use std::io::{self, Error}; + +pub struct BytePacketBuffer { + pub buf: [u8; 512], + pub pos: usize, +} + +impl BytePacketBuffer { + /// This gives us a fresh buffer for holding the packet contents, and a + /// field for keeping track of where we are. + pub fn new () -> BytePacketBuffer { + return BytePacketBuffer { + buf: [0; 512], + pos: 0 + }; + } + + /// Current position within buffer + pub fn pos(&self) -> usize{ + return self.pos; + } + + /// Step the buffer position forward a specific number of steps + pub fn step(&mut self, steps: usize) -> Result<(), io::Error> { + if self.pos + steps >= 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + self.pos += steps; + Ok(()) + } + + /// Change the buffer position + pub fn seek(&mut self, pos: usize) -> Result<(), io::Error> { + if pos >= 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + self.pos = pos; + Ok(()) + } + + /// Read a single byte and move the position one step forward + pub fn read(&mut self) -> Result { + if self.pos >= 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + let res = self.buf[self.pos]; + self.pos += 1; + Ok(res) + } + + /// Get a single byte, without changing the buffer position + pub fn get(& self, pos: usize) -> Result { + if pos >= 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + Ok(self.buf[pos]) + } + + /// Get a range of bytes + pub fn get_range(& self, start: usize, len: usize) -> Result<&[u8], io::Error> { + if start+len > 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + Ok(&self.buf[start..(start+len) as usize]) + } + + /// Read two bytes, stepping two steps forward + pub fn read_u16(&mut self) -> Result { + let res = ( ( self.read()? as u16) << 8) | ( self.read()? as u16); + Ok(res) + } + + /// Read four bytes, stepping four steps forward + pub fn read_u32(&mut self) -> Result { + let res = ((self.read()? as u32) << 24) + | ((self.read()? as u32) << 16) + | ((self.read()? as u32) << 8) + | ((self.read()? as u32) << 0); + + Ok(res) + } + + /// Read a qname + /// + /// Will take something like [3]www[6]google[3]com[0] and append + /// www.google.com to outstr. + /// + /// also it handle jemps . + pub fn read_qname(&mut self, outstr: &mut String) -> Result<(), io::Error> { + + let mut pos = self.pos(); + let mut jumped = false; + let max_jumps = 5; + let mut jumps_performed = 0; + let mut delim = ""; + loop { + if jumps_performed > max_jumps { + return Err(Error::new(io::ErrorKind::UnexpectedEof, format!("Limit of {} jumps exceeded", max_jumps))); + } + + let len = self.get(pos)?; + + if (len & 0xC0) == 0xC0 { + if !jumped { + self.seek(pos+2)? + } + + let b2 = self.get(pos + 1)? as u16; + let offset = (((len as u16) ^ 0xC0) << 8) | b2; + pos = offset as usize; + + jumped = true; + jumps_performed += 1; + }else { + pos += 1; + + if len == 0 { + break; + } + outstr.push_str(delim); + let str_buffer = self.get_range(pos, len as usize)?; + outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase()); + + delim = "."; + pos += len as usize; + } + } + if !jumped { + self.seek(pos)?; + } + + Ok(()) + } + + /// Write a single byte and move the position one step forward + pub fn write(&mut self, val: u8) -> Result<(), io::Error> { + if self.pos >= 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + self.buf[self.pos] = val; + self.pos += 1; + Ok(()) + } + + /// Read a single byte and move the position one step forward + pub fn write_u8(&mut self, val: u8) -> Result<(), io::Error> { + self.write(val)?; + + Ok(()) + } + + /// Read a two bytes and move the position two steps forward + pub fn write_u16(&mut self, val: u16) -> Result<(), io::Error> { + self.write((val >> 8) as u8)?; + self.write((val & 0xFF) as u8)?; + + Ok(()) + } + + /// Write four bytes, stepping four steps forward + pub fn write_u32(&mut self, val: u32) -> Result<(), io::Error> { + self.write(((val >> 24) & 0xFF) as u8)?; + self.write(((val >> 16) & 0xFF) as u8)?; + self.write(((val >> 8) & 0xFF) as u8)?; + self.write(((val >> 0) & 0xFF) as u8)?; + + Ok(()) + } + + /// Write a qname + /// + /// Will take something like www.google.com + /// + /// dots are the separator used + pub fn write_qname(&mut self, qname: &str) -> Result<(), io::Error> { + for label in qname.split('.') { + let len = label.len(); + if len > 0x3f { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "Single label exceeds 63 characters of length")); + } + + self.write_u8(len as u8)?; + for b in label.as_bytes() { + self.write_u8(*b)?; + } + } + + self.write_u8(0)?; + + Ok(()) + } + + /// Write a single byte, without changing the buffer position + pub fn set(&mut self, pos: usize, val: u8) -> Result<(), io::Error> { + if pos >= 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + self.buf[pos] = val; + + Ok(()) + } + + /// Get two bytes, without changing the buffer position + pub fn set_u16(&mut self, pos: usize, val: u16) -> Result<(), io::Error> { + self.set(pos, (val >> 8) as u8)?; + self.set(pos + 1, (val & 0xFF) as u8)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_initial_position() { + let buf = BytePacketBuffer::new(); + assert_eq!(buf.pos(), 0); + } + + #[test] + fn test_pos_step_seek_limit() { + let mut buf = BytePacketBuffer::new(); + assert!(buf.step(511).is_ok()); + assert_eq!(buf.pos(), 511); + assert!(buf.step(1).is_err()); + assert!(buf.seek(511).is_ok()); + assert!(buf.seek(512).is_err()); + } + + #[test] + fn test_read_write_limits() { + let mut buf = BytePacketBuffer::new(); + + // Test reading at limit + assert!(buf.seek(511).is_ok()); + assert!(buf.read().is_ok()); + assert!(buf.read().is_err()); + + // Test writing at limit + assert!(buf.seek(511).is_ok()); + assert!(buf.write(0xFF).is_ok()); + assert!(buf.write(0xFF).is_err()); + assert_eq!(buf.buf[511], 0xFF); + } + + #[test] + fn test_read_u16_u32() { + let mut buf = BytePacketBuffer::new(); + buf.buf[0] = 0x12; + buf.buf[1] = 0x34; + buf.buf[2] = 0x56; + buf.buf[3] = 0x78; + + assert_eq!(buf.read_u16().unwrap(), 0x1234); + assert!(buf.seek(0).is_ok()); + assert_eq!(buf.read_u32().unwrap(), 0x12345678); + } + + #[test] + fn test_set_get() { + let mut buf = BytePacketBuffer::new(); + assert!(buf.set(511, 0xAB).is_ok()); + assert_eq!(buf.get(511).unwrap(), 0xAB); + assert!(buf.set(512, 0xAB).is_err()); + } + + #[test] + fn test_read_qname_no_jump() { + let mut buf = BytePacketBuffer::new(); + buf.buf[0] = 3; + buf.buf[1] = b'w'; + buf.buf[2] = b'w'; + buf.buf[3] = b'w'; + buf.buf[4] = 6; + buf.buf[5] = b'g'; + buf.buf[6] = b'o'; + buf.buf[7] = b'o'; + buf.buf[8] = b'g'; + buf.buf[9] = b'l'; + buf.buf[10] = b'e'; + buf.buf[11] = 3; + buf.buf[12] = b'c'; + buf.buf[13] = b'o'; + buf.buf[14] = b'm'; + buf.buf[15] = 0; + + let mut outstr = String::new(); + assert!(buf.read_qname(&mut outstr).is_ok()); + assert_eq!(outstr, "www.google.com"); + } + + #[test] + fn test_read_qname_with_single_jump() { + let mut buf = BytePacketBuffer::new(); + buf.buf[20] = 3; + buf.buf[21] = b'w'; + buf.buf[22] = b'w'; + buf.buf[23] = b'w'; + buf.buf[24] = 6; + buf.buf[25] = b'g'; + buf.buf[26] = b'o'; + buf.buf[27] = b'o'; + buf.buf[28] = b'g'; + buf.buf[29] = b'l'; + buf.buf[30] = b'e'; + buf.buf[31] = 3; + buf.buf[32] = b'c'; + buf.buf[33] = b'o'; + buf.buf[34] = b'm'; + buf.buf[35] = 0; + + buf.buf[0] = 0xC0; + buf.buf[1] = 20; + + let mut outstr = String::new(); + assert!(buf.read_qname(&mut outstr).is_ok()); + assert_eq!(outstr, "www.google.com"); + } + + #[test] + fn test_read_qname_with_multiple_jumps() { + let mut buf = BytePacketBuffer::new(); + // "google.com" at position 20 + buf.buf[20] = 6; + buf.buf[21] = b'g'; + buf.buf[22] = b'o'; + buf.buf[23] = b'o'; + buf.buf[24] = b'g'; + buf.buf[25] = b'l'; + buf.buf[26] = b'e'; + buf.buf[27] = 3; + buf.buf[28] = b'c'; + buf.buf[29] = b'o'; + buf.buf[30] = b'm'; + buf.buf[31] = 0; + + // "www" at position 50 + buf.buf[50] = 3; + buf.buf[51] = b'w'; + buf.buf[52] = b'w'; + buf.buf[53] = b'w'; + buf.buf[54] = 0xC0; // Pointer to "google.com" + buf.buf[55] = 20; + + // Pointer to "www" at start + buf.buf[0] = 0xC0; + buf.buf[1] = 50; + + let mut outstr = String::new(); + assert!(buf.read_qname(&mut outstr).is_ok()); + assert_eq!(outstr, "www.google.com"); + } + + + #[test] + fn test_write_qname_success() { + let mut buf = BytePacketBuffer::new(); + assert!(buf.write_qname("www.google.com").is_ok()); + + let mut outstr = String::new(); + assert!(buf.seek(0).is_ok()); + assert!(buf.read_qname(&mut outstr).is_ok()); + assert_eq!(outstr, "www.google.com"); + } + + #[test] + fn test_write_qname_overflow() { + let mut buf = BytePacketBuffer::new(); + assert!(buf.seek(510).is_ok()); + assert!(buf.write_qname("a").is_err()); + } + +} \ No newline at end of file diff --git a/src/dns_header.rs b/src/dns_header.rs new file mode 100644 index 0000000..18f23d4 --- /dev/null +++ b/src/dns_header.rs @@ -0,0 +1,254 @@ +use std::io; + +use crate::byte_packet_buffer::BytePacketBuffer; + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum ResultCode { + NOERROR = 0, + FORMERR = 1, + SERVFAIL = 2, + NXDOMAIN = 3, + NOTIMP = 4, + REFUSED = 5, +} + +impl ResultCode { + pub fn from_num(num: u8) -> ResultCode { + match num { + 1 => ResultCode::FORMERR, + 2 => ResultCode::SERVFAIL, + 3 => ResultCode::NXDOMAIN, + 4 => ResultCode::NOTIMP, + 5 => ResultCode::REFUSED, + 0 | _ => ResultCode::NOERROR, + } + } +} + + + +#[derive(Clone, Debug)] +pub struct DnsHeader { + pub id: u16, // 16 bits + + pub recursion_desired: bool, // 1 bit + pub truncated_message: bool, // 1 bit + pub authoritative_answer: bool, // 1 bit + pub opcode: u8, // 4 bits + pub response: bool, // 1 bit + + pub rescode: ResultCode, // 4 bits + pub checking_disabled: bool, // 1 bit + pub authed_data: bool, // 1 bit + pub z: bool, // 1 bit + pub recursion_available: bool, // 1 bit + + pub questions: u16, // 16 bits + pub answers: u16, // 16 bits + pub authoritative_entries: u16, // 16 bits + pub resource_entries: u16, // 16 bits +} + +impl DnsHeader { + pub fn new() -> DnsHeader { + DnsHeader { + id: 0, + + recursion_desired: false, + truncated_message: false, + authoritative_answer: false, + opcode: 0, + response: false, + + rescode: ResultCode::NOERROR, + checking_disabled: false, + authed_data: false, + z: false, + recursion_available: false, + + questions: 0, + answers: 0, + authoritative_entries: 0, + resource_entries: 0, + } + } + /// Read Dns packet header. + pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { + self.id = buffer.read_u16()?; + + let mut flags = buffer.read()?; + self.recursion_desired = (flags & (1 << 0)) > 0; + self.truncated_message = (flags & (1 << 1)) > 0; + self.authoritative_answer = (flags & (1 << 2)) > 0; + self.opcode = (flags >> 3) & 0x0F; + self.response = (flags & (1 << 7)) > 0; + + flags = buffer.read()?; + self.rescode = ResultCode::from_num(flags & 0x0F); + self.checking_disabled = (flags & (1 << 4)) > 0; + self.authed_data = (flags & (1 << 5)) > 0; + self.z = (flags & (1 << 6)) > 0; + self.recursion_available = (flags & (1 << 7)) > 0; + + self.questions = buffer.read_u16()?; + self.answers = buffer.read_u16()?; + self.authoritative_entries = buffer.read_u16()?; + self.resource_entries = buffer.read_u16()?; + + Ok(()) + } + /// Write Dns packet header as bytes to buffer. + pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { + buffer.write_u16(self.id)?; + + buffer.write_u8( + (self.recursion_desired as u8) + | ((self.truncated_message as u8) << 1) + | ((self.authoritative_answer as u8) << 2) + | (self.opcode << 3) + | ((self.response as u8) << 7) as u8, + )?; + + buffer.write_u8( + (self.rescode as u8) + | ((self.checking_disabled as u8) << 4) + | ((self.authed_data as u8) << 5) + | ((self.z as u8) << 6) + | ((self.recursion_available as u8) << 7), + )?; + + buffer.write_u16(self.questions)?; + buffer.write_u16(self.answers)?; + buffer.write_u16(self.authoritative_entries)?; + buffer.write_u16(self.resource_entries)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_read_header_from_buffer() { + let mut buf = BytePacketBuffer::new(); + + buf.buf[0] = 0x12; + buf.buf[1] = 0x34; + + buf.buf[2] = 0b10010101; // flags: recursion_desired=1, truncated_message=0, authoritative_answer=1, opcode=2, response=1 + buf.buf[3] = 0b11100101; // flags: rescode=5, checking_disabled=0, authed_data=1, z=1, recursion_available=1 + + buf.buf[4] = 0x00; + buf.buf[5] = 0x01; // questions = 1 + buf.buf[6] = 0x00; + buf.buf[7] = 0x02; // answers = 2 + buf.buf[8] = 0x00; + buf.buf[9] = 0x03; // authoritative_entries = 3 + buf.buf[10] = 0x00; + buf.buf[11] = 0x04; // resource_entries = 4 + + let mut header = DnsHeader::new(); + assert!(header.read(&mut buf).is_ok()); + + assert_eq!(header.id, 0x1234); + assert_eq!(header.recursion_desired, true); + assert_eq!(header.truncated_message, false); + assert_eq!(header.authoritative_answer, true); + assert_eq!(header.opcode, 2); + assert_eq!(header.response, true); + assert_eq!(header.rescode, ResultCode::REFUSED); + assert_eq!(header.checking_disabled, false); + assert_eq!(header.authed_data, true); + assert_eq!(header.z, true); + assert_eq!(header.recursion_available, true); + assert_eq!(header.questions, 1); + assert_eq!(header.answers, 2); + assert_eq!(header.authoritative_entries, 3); + assert_eq!(header.resource_entries, 4); + } + + #[test] + fn test_write_header_to_buffer() { + let mut header = DnsHeader::new(); + header.id = 0x1234; + header.recursion_desired = true; + header.truncated_message = false; + header.authoritative_answer = true; + header.opcode = 2; + header.response = true; + header.rescode = ResultCode::REFUSED; + header.checking_disabled = false; + header.authed_data = true; + header.z = true; + header.recursion_available = true; + header.questions = 1; + header.answers = 2; + header.authoritative_entries = 3; + header.resource_entries = 4; + + let mut buf = BytePacketBuffer::new(); + assert!(header.write(&mut buf).is_ok()); + + assert_eq!(buf.buf[0], 0x12); + assert_eq!(buf.buf[1], 0x34); // id = 0x1234 + + assert_eq!(buf.buf[2], 0b10010101); + assert_eq!(buf.buf[3], 0b11100101); + + assert_eq!(buf.buf[4], 0x00); + assert_eq!(buf.buf[5], 0x01); // questions = 1 + assert_eq!(buf.buf[6], 0x00); + assert_eq!(buf.buf[7], 0x02); // answers = 2 + assert_eq!(buf.buf[8], 0x00); + assert_eq!(buf.buf[9], 0x03); // authoritative_entries = 3 + assert_eq!(buf.buf[10], 0x00); + assert_eq!(buf.buf[11], 0x04); // resource_entries = 4 + } + + #[test] + fn test_read_write() { + let mut header = DnsHeader::new(); + header.id = 0x4321; + header.recursion_desired = true; + header.truncated_message = true; + header.authoritative_answer = true; + header.opcode = 3; + header.response = false; + header.rescode = ResultCode::REFUSED; + header.checking_disabled = true; + header.authed_data = true; + header.z = false; + header.recursion_available = true; + header.questions = 10; + header.answers = 20; + header.authoritative_entries = 30; + header.resource_entries = 40; + + let mut buf = BytePacketBuffer::new(); + assert!(header.write(&mut buf).is_ok()); + + let mut read_header = DnsHeader::new(); + assert!(buf.seek(0).is_ok()); + assert!(read_header.read(&mut buf).is_ok()); + println!("{read_header:?}"); + + + assert_eq!(header.id, read_header.id); + assert_eq!(header.recursion_desired, read_header.recursion_desired); + assert_eq!(header.truncated_message, read_header.truncated_message); + assert_eq!(header.authoritative_answer, read_header.authoritative_answer); + assert_eq!(header.opcode, read_header.opcode); + assert_eq!(header.response, read_header.response); + assert_eq!(header.rescode, read_header.rescode); + assert_eq!(header.checking_disabled, read_header.checking_disabled); + assert_eq!(header.authed_data, read_header.authed_data); + assert_eq!(header.z, read_header.z); + assert_eq!(header.recursion_available, read_header.recursion_available); + assert_eq!(header.questions, read_header.questions); + assert_eq!(header.answers, read_header.answers); + assert_eq!(header.authoritative_entries, read_header.authoritative_entries); + assert_eq!(header.resource_entries, read_header.resource_entries); + } +} \ No newline at end of file diff --git a/src/dns_packet.rs b/src/dns_packet.rs new file mode 100644 index 0000000..679b2c9 --- /dev/null +++ b/src/dns_packet.rs @@ -0,0 +1,268 @@ +use std::{io, net::Ipv4Addr}; + +use crate::{byte_packet_buffer::BytePacketBuffer, dns_header::DnsHeader, dns_question::DnsQuestion, dns_record::DnsRecord, query_types::QueryType}; + +#[derive(Clone, Debug)] +pub struct DnsPacket { + pub header: DnsHeader, + pub questions: Vec, + pub answers: Vec, + pub authorities: Vec, + pub resources: Vec, +} + +impl DnsPacket { + pub fn new() -> DnsPacket { + DnsPacket { + header: DnsHeader::new(), + questions: Vec::new(), + answers: Vec::new(), + authorities: Vec::new(), + resources: Vec::new(), + } + } + + /// Read complete Dns packet from BytePacketBuffer + pub fn from_buffer(buffer: &mut BytePacketBuffer) -> Result { + let mut result = DnsPacket::new(); + result.header.read(buffer)?; + + for _ in 0..result.header.questions { + let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0)); + question.read(buffer)?; + result.questions.push(question); + } + + for _ in 0..result.header.answers { + let rec = DnsRecord::read(buffer)?; + result.answers.push(rec); + } + for _ in 0..result.header.authoritative_entries { + let rec = DnsRecord::read(buffer)?; + result.authorities.push(rec); + } + for _ in 0..result.header.resource_entries { + let rec = DnsRecord::read(buffer)?; + result.resources.push(rec); + } + + Ok(result) + } + + /// Write complete Dns packet to BytePacketBuffer + pub fn write(&mut self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { + self.header.questions = self.questions.len() as u16; + self.header.answers = self.answers.len() as u16; + self.header.authoritative_entries = self.authorities.len() as u16; + self.header.resource_entries = self.resources.len() as u16; + + self.header.write(buffer)?; + + for question in &self.questions { + question.write(buffer)?; + } + for rec in &self.answers { + rec.write(buffer)?; + } + for rec in &self.authorities { + rec.write(buffer)?; + } + for rec in &self.resources { + rec.write(buffer)?; + } + + Ok(()) + } + + /// Get Random A to be able to pick a random A record from a packet. When we + /// get multiple IP's for a single name, it doesn't matter which one we + /// choose, so in those cases we can now pick one at random. + pub fn get_random_a(&self) -> Option { + self.answers + .iter() + .filter_map(|record| match record { + DnsRecord::A { addr, .. } => Some(*addr), + _ => None, + }) + .next() + } + + /// Get NS helper function which returns an iterator over all name servers in + /// the authorities section, represented as (domain, host) tuples + fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator { + self.authorities + .iter() + .filter_map(|record| match record { + DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())), + _ => None, + }) + .filter(move |(domain, _)| qname.ends_with(*domain)) + } + + /// Get Resolved NS + /// as We'll use the fact that name servers often bundle the corresponding + /// A records when replying to an NS query to implement a function that + /// returns the actual IP for an NS record if possible. + pub fn get_resolved_ns(&self, qname: &str) -> Option { + self.get_ns(qname) + .flat_map(|(_, host)| { + self.resources + .iter() + .filter_map(move |record| match record { + DnsRecord::A { domain, addr, .. } if domain == host => Some(addr), + _ => None, + }) + }) + .map(|addr| *addr) + .next() + } + + /// Get Unresolved NS a method for returning the host + /// name of an appropriate name server. + pub fn get_unresolved_ns<'a>(&'a self, qname: &'a str) -> Option<&'a str> { + self.get_ns(qname) + .map(|(_, host)| host) + .next() + } +} + + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_write_and_read_dns_packet() { + let mut buffer = BytePacketBuffer::new(); + let mut packet = DnsPacket::new(); + + // Setup a DNS question + let question = DnsQuestion::new("google.com".to_string(), QueryType::A); + packet.questions.push(question); + + // Setup a DNS answer + let answer = DnsRecord::A { + domain: "google.com".to_string(), + addr: Ipv4Addr::new(216, 58, 211, 142), + ttl: 293, + }; + packet.answers.push(answer); + + // Setup a DNS authority + let authority = DnsRecord::NS { + domain: "google.com".to_string(), + host: "ns1.google.com".to_string(), + ttl: 293, + }; + packet.authorities.push(authority); + + // Setup a DNS resource + let resource = DnsRecord::A { + domain: "ns1.google.com".to_string(), + addr: Ipv4Addr::new(8, 8, 8, 8), + ttl: 293, + }; + packet.resources.push(resource); + + packet.write(&mut buffer).expect("Failed to write DNS packet"); + + assert!(buffer.seek(0).is_ok()); + + let read_packet = DnsPacket::from_buffer(&mut buffer).expect("Failed to read DNS packet"); + + // Validate the read packet + assert_eq!(read_packet.questions.len(), 1); + assert_eq!(read_packet.answers.len(), 1); + assert_eq!(read_packet.authorities.len(), 1); + assert_eq!(read_packet.resources.len(), 1); + + // Validate the question + match &read_packet.questions[0] { + DnsQuestion { name, qtype, .. } => { + assert_eq!(name, "google.com"); + assert_eq!(qtype, &QueryType::A); + } + } + + // Validate the answer + match &read_packet.answers[0] { + DnsRecord::A { domain, addr, ttl } => { + assert_eq!(domain, "google.com"); + assert_eq!(addr, &Ipv4Addr::new(216, 58, 211, 142)); + assert_eq!(*ttl, 293); + } + _ => panic!("Expected A record"), + } + + // Validate the authority + match &read_packet.authorities[0] { + DnsRecord::NS { domain, host, ttl } => { + assert_eq!(domain, "google.com"); + assert_eq!(host, "ns1.google.com"); + assert_eq!(*ttl, 293); + } + _ => panic!("Expected NS record"), + } + + // Validate the resource + match &read_packet.resources[0] { + DnsRecord::A { domain, addr, ttl } => { + assert_eq!(domain, "ns1.google.com"); + assert_eq!(addr, &Ipv4Addr::new(8, 8, 8, 8)); + assert_eq!(*ttl, 293); + } + _ => panic!("Expected A record"), + } + } + + #[test] + fn test_get_random_a() { + let mut packet = DnsPacket::new(); + packet.answers.push(DnsRecord::A { + domain: "google.com".to_string(), + addr: Ipv4Addr::new(216, 58, 211, 142), + ttl: 293, + }); + packet.answers.push(DnsRecord::A { + domain: "example.com".to_string(), + addr: Ipv4Addr::new(93, 184, 216, 34), + ttl: 293, + }); + + let random_a = packet.get_random_a(); + assert!(random_a.is_some()); + let random_a = random_a.unwrap(); + assert!(random_a == Ipv4Addr::new(216, 58, 211, 142) || random_a == Ipv4Addr::new(93, 184, 216, 34)); + } + + #[test] + fn test_get_resolved_ns() { + let mut packet = DnsPacket::new(); + packet.authorities.push(DnsRecord::NS { + domain: "google.com".to_string(), + host: "ns1.google.com".to_string(), + ttl: 293, + }); + packet.resources.push(DnsRecord::A { + domain: "ns1.google.com".to_string(), + addr: Ipv4Addr::new(8, 8, 8, 8), + ttl: 293, + }); + + let resolved_ns = packet.get_resolved_ns("google.com"); + assert_eq!(resolved_ns, Some(Ipv4Addr::new(8, 8, 8, 8))); + } + + #[test] + fn test_get_unresolved_ns() { + let mut packet = DnsPacket::new(); + packet.authorities.push(DnsRecord::NS { + domain: "google.com".to_string(), + host: "ns1.google.com".to_string(), + ttl: 293, + }); + + let unresolved_ns = packet.get_unresolved_ns("google.com"); + assert_eq!(unresolved_ns, Some("ns1.google.com")); + } +} diff --git a/src/dns_question.rs b/src/dns_question.rs new file mode 100644 index 0000000..00b719a --- /dev/null +++ b/src/dns_question.rs @@ -0,0 +1,91 @@ +use std::io; + +use crate::{byte_packet_buffer::BytePacketBuffer, query_types::QueryType}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub struct DnsQuestion { + pub name: String, + pub qtype: QueryType, +} + +impl DnsQuestion { + pub fn new(name: String, qtype: QueryType) -> DnsQuestion { + DnsQuestion { + name: name, + qtype: qtype, + } + } + /// Read Dns question. + pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { + buffer.read_qname(&mut self.name)?; + self.qtype = QueryType::from_num(buffer.read_u16()?); // qtype + let _ = buffer.read_u16()?; // class + Ok(()) + } + + /// Write Dns question as bytes to buffer. + pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { + buffer.write_qname(&self.name)?; + + let typenum = self.qtype.to_num(); + buffer.write_u16(typenum)?; + buffer.write_u16(1)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_read_question_from_buffer() { + let mut buf = BytePacketBuffer::new(); + + buf.buf[0] = 6; + buf.buf[1..7].copy_from_slice(b"google"); + buf.buf[7] = 3; + buf.buf[8..11].copy_from_slice(b"com"); + buf.buf[11] = 0; + buf.buf[12] = 0x00; + buf.buf[13] = 0x01; // qtype A + + let mut question = DnsQuestion::new("".to_string(), QueryType::A); + assert!(question.read(&mut buf).is_ok()); + + assert_eq!(question.name, "google.com"); + assert_eq!(question.qtype, QueryType::A); + } + + #[test] + fn test_write_question_to_buffer() { + let question = DnsQuestion::new("google.com".to_string(), QueryType::A); + + let mut buf = BytePacketBuffer::new(); + assert!(question.write(&mut buf).is_ok()); + + assert_eq!(buf.buf[0], 6); + assert_eq!(&buf.buf[1..7], b"google"); + assert_eq!(buf.buf[7], 3); + assert_eq!(&buf.buf[8..11], b"com"); + assert_eq!(buf.buf[11], 0); + assert_eq!(buf.buf[12], 0x00); + assert_eq!(buf.buf[13], 0x01); // qtype A + } + + #[test] + fn test_read_write() { + let question = DnsQuestion::new("google.com".to_string(), QueryType::A); + + let mut buf = BytePacketBuffer::new(); + assert!(question.write(&mut buf).is_ok()); + + let mut read_question = DnsQuestion::new("".to_string(), QueryType::A); + assert!(buf.seek(0).is_ok()); + assert!(read_question.read(&mut buf).is_ok()); + + assert_eq!(question.name, read_question.name); + assert_eq!(question.qtype, read_question.qtype); + } +} \ No newline at end of file diff --git a/src/dns_record.rs b/src/dns_record.rs new file mode 100644 index 0000000..dfec68c --- /dev/null +++ b/src/dns_record.rs @@ -0,0 +1,676 @@ +use std::{io, net::{Ipv4Addr, Ipv6Addr}}; + +use crate::{byte_packet_buffer::BytePacketBuffer, QueryType}; + +#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +pub enum DnsRecord { + UNKNOWN { + domain: String, + qtype: u16, + data_len: u16, + ttl: u32, + }, // 0 + A { + domain: String, + addr: Ipv4Addr, + ttl: u32, + }, // 1 + NS { + domain: String, + host: String, + ttl: u32, + }, // 2 + CNAME { + domain: String, + host: String, + ttl: u32, + }, // 5 + MX { + domain: String, + priority: u16, + host: String, + ttl: u32, + }, // 15 + AAAA { + domain: String, + addr: Ipv6Addr, + ttl: u32, + }, // 28 +} + +impl DnsRecord { + /// Read Dns record (Answer Section - Authority Section - Additional Section) + pub fn read(buffer: &mut BytePacketBuffer) -> Result { + let mut domain = String::new(); + buffer.read_qname(&mut domain)?; + + let qtype_num = buffer.read_u16()?; + let qtype = QueryType::from_num(qtype_num); + let _ = buffer.read_u16()?; + let ttl = buffer.read_u32()?; + let data_len = buffer.read_u16()?; + + match qtype { + QueryType::A => { + let raw_addr = buffer.read_u32()?; + let addr = Ipv4Addr::new( + ((raw_addr >> 24) & 0xFF) as u8, + ((raw_addr >> 16) & 0xFF) as u8, + ((raw_addr >> 8) & 0xFF) as u8, + ((raw_addr >> 0) & 0xFF) as u8, + ); + + Ok(DnsRecord::A { + domain: domain, + addr: addr, + ttl: ttl, + }) + } + QueryType::AAAA => { + let raw_addr1 = buffer.read_u32()?; + let raw_addr2 = buffer.read_u32()?; + let raw_addr3 = buffer.read_u32()?; + let raw_addr4 = buffer.read_u32()?; + let addr = Ipv6Addr::new( + ((raw_addr1 >> 16) & 0xFFFF) as u16, + ((raw_addr1 >> 0) & 0xFFFF) as u16, + ((raw_addr2 >> 16) & 0xFFFF) as u16, + ((raw_addr2 >> 0) & 0xFFFF) as u16, + ((raw_addr3 >> 16) & 0xFFFF) as u16, + ((raw_addr3 >> 0) & 0xFFFF) as u16, + ((raw_addr4 >> 16) & 0xFFFF) as u16, + ((raw_addr4 >> 0) & 0xFFFF) as u16, + ); + + Ok(DnsRecord::AAAA { + domain: domain, + addr: addr, + ttl: ttl, + }) + } + QueryType::NS => { + let mut ns = String::new(); + buffer.read_qname(&mut ns)?; + + Ok(DnsRecord::NS { + domain: domain, + host: ns, + ttl: ttl, + }) + } + QueryType::CNAME => { + let mut cname = String::new(); + buffer.read_qname(&mut cname)?; + + Ok(DnsRecord::CNAME { + domain: domain, + host: cname, + ttl: ttl, + }) + } + QueryType::MX => { + let priority = buffer.read_u16()?; + let mut mx = String::new(); + buffer.read_qname(&mut mx)?; + + Ok(DnsRecord::MX { + domain: domain, + priority: priority, + host: mx, + ttl: ttl, + }) + } + QueryType::UNKNOWN(_) => { + buffer.step(data_len as usize)?; + + Ok(DnsRecord::UNKNOWN { + domain: domain, + qtype: qtype_num, + data_len: data_len, + ttl: ttl, + }) + } + } + } + + /// Write Dns record (Answer Section - Authority Section - Additional Section) + pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result { + let start_pos = buffer.pos(); + + match *self { + DnsRecord::A { + ref domain, + ref addr, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::A.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + buffer.write_u16(4)?; + + let octets = addr.octets(); + buffer.write_u8(octets[0])?; + buffer.write_u8(octets[1])?; + buffer.write_u8(octets[2])?; + buffer.write_u8(octets[3])?; + } + DnsRecord::NS { + ref domain, + ref host, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::NS.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::CNAME { + ref domain, + ref host, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::CNAME.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::MX { + ref domain, + priority, + ref host, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::MX.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_u16(priority)?; + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::AAAA { + ref domain, + ref addr, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::AAAA.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + buffer.write_u16(16)?; + + for octet in &addr.segments() { + buffer.write_u16(*octet)?; + } + } + DnsRecord::UNKNOWN { .. } => { + println!("Skipping record: {:?}", self); + } + } + + Ok(buffer.pos() - start_pos) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_read_a_record() { + let mut buf = BytePacketBuffer::new(); + + buf.buf[0] = 6; + buf.buf[1..7].copy_from_slice(b"google"); + buf.buf[7] = 3; + buf.buf[8..11].copy_from_slice(b"com"); + buf.buf[11] = 0; + buf.buf[12] = 0x00; + buf.buf[13] = 0x01; // qtype A + buf.buf[14] = 0x00; + buf.buf[15] = 0x01; // class IN + buf.buf[16] = 0x00; + buf.buf[17] = 0x00; + buf.buf[18] = 0x01; + buf.buf[19] = 0x25; // ttl = 293 + buf.buf[20] = 0x00; + buf.buf[21] = 0x04; // data_len = 4 (IPv4 length) + buf.buf[22] = 0xd8; // 216.58.211.142 + buf.buf[23] = 0x3a; + buf.buf[24] = 0xd3; + buf.buf[25] = 0x8e; + + let record = DnsRecord::read(&mut buf).expect("Failed to read A record"); + + match record { + DnsRecord::A { domain, addr, ttl } => { + assert_eq!(domain, "google.com"); + assert_eq!(addr, Ipv4Addr::new(216, 58, 211, 142)); + assert_eq!(ttl, 293); + } + _ => panic!("Record is not of type A"), + } + } + + #[test] + fn test_write_a_record() { + let record = DnsRecord::A { + domain: "google.com".to_string(), + addr: Ipv4Addr::new(216, 58, 211, 142), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write A record"); + + assert_eq!(buf.buf[0], 6); + assert_eq!(&buf.buf[1..7], b"google"); + assert_eq!(buf.buf[7], 3); + assert_eq!(&buf.buf[8..11], b"com"); + assert_eq!(buf.buf[11], 0); + assert_eq!(buf.buf[12], 0x00); + assert_eq!(buf.buf[13], 0x01); // qtype A + assert_eq!(buf.buf[14], 0x00); + assert_eq!(buf.buf[15], 0x01); // class IN + assert_eq!(buf.buf[16], 0x00); + assert_eq!(buf.buf[17], 0x00); + assert_eq!(buf.buf[18], 0x01); + assert_eq!(buf.buf[19], 0x25); // ttl = 293 + assert_eq!(buf.buf[20], 0x00); + assert_eq!(buf.buf[21], 0x04); // data_len = 4 + assert_eq!(&buf.buf[22..26], &[216, 58, 211, 142]); + } + + #[test] + fn test_a_record() { + let record = DnsRecord::A { + domain: "google.com".to_string(), + addr: Ipv4Addr::new(216, 58, 211, 142), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write A record"); + assert!(buf.seek(0).is_ok()); + let read_record = DnsRecord::read(&mut buf).expect("Failed to read A record"); + + assert_eq!(record, read_record); + } + + + #[test] + fn test_read_aaaa_record() { + let mut buf = BytePacketBuffer::new(); + + buf.buf[0] = 6; + buf.buf[1..7].copy_from_slice(b"google"); + buf.buf[7] = 3; + buf.buf[8..11].copy_from_slice(b"com"); + buf.buf[11] = 0; + buf.buf[12] = 0x00; + buf.buf[13] = 0x1C; // qtype AAAA + buf.buf[14] = 0x00; + buf.buf[15] = 0x01; // class IN + buf.buf[16] = 0x00; + buf.buf[17] = 0x00; + buf.buf[18] = 0x01; + buf.buf[19] = 0x25; // ttl = 293 + buf.buf[20] = 0x00; + buf.buf[21] = 0x10; // data_len = 16 (IPv6 length) + buf.buf[22..38].copy_from_slice(&[ + 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x10, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + ]); + + let record = DnsRecord::read(&mut buf).expect("Failed to read AAAA record"); + + match record { + DnsRecord::AAAA { domain, addr, ttl } => { + assert_eq!(domain, "google.com"); + assert_eq!(addr, Ipv6Addr::new( + 0x2001, 0xdb8, 0x0010, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0001, + )); + assert_eq!(ttl, 293); + } + _ => panic!("Record is not of type AAAA"), + } + } + + #[test] + fn test_write_aaaa_record() { + let record = DnsRecord::AAAA { + domain: "google.com".to_string(), + addr: Ipv6Addr::new( + 0x2001, 0xdb8, 0x0010, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0001, + ), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write AAAA record"); + + assert_eq!(buf.buf[0], 6); + assert_eq!(&buf.buf[1..7], b"google"); + assert_eq!(buf.buf[7], 3); + assert_eq!(&buf.buf[8..11], b"com"); + assert_eq!(buf.buf[11], 0); + assert_eq!(buf.buf[12], 0x00); + assert_eq!(buf.buf[13], 0x1C); // qtype AAAA + assert_eq!(buf.buf[14], 0x00); + assert_eq!(buf.buf[15], 0x01); // class IN + assert_eq!(buf.buf[16], 0x00); + assert_eq!(buf.buf[17], 0x00); + assert_eq!(buf.buf[18], 0x01); + assert_eq!(buf.buf[19], 0x25); // ttl = 293 + assert_eq!(buf.buf[20], 0x00); + assert_eq!(buf.buf[21], 0x10); // data_len = 16 + assert_eq!(&buf.buf[22..38], &[ + 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x10, 0x00, 0x00, + 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + ]); + } + + #[test] + fn test_aaaa_record() { + let record = DnsRecord::AAAA { + domain: "google.com".to_string(), + addr: Ipv6Addr::new( + 0x2001, 0xdb8, 0x0010, 0x0000, + 0x0000, 0x0000, 0x0000, 0x0001, + ), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write AAAA record"); + assert!(buf.seek(0).is_ok()); + let read_record = DnsRecord::read(&mut buf).expect("Failed to read AAAA record"); + + assert_eq!(record, read_record); + } + + #[test] + fn test_read_cname_record() { + let mut buf = BytePacketBuffer::new(); + + buf.buf[0] = 6; + buf.buf[1..7].copy_from_slice(b"google"); + buf.buf[7] = 3; + buf.buf[8..11].copy_from_slice(b"com"); + buf.buf[11] = 0; + buf.buf[12] = 0x00; + buf.buf[13] = 0x05; // qtype CNAME + buf.buf[14] = 0x00; + buf.buf[15] = 0x01; // class IN + buf.buf[16] = 0x00; + buf.buf[17] = 0x00; + buf.buf[18] = 0x01; + buf.buf[19] = 0x25; // ttl = 293 + buf.buf[20] = 0x00; + buf.buf[21] = 0x05; // data_len = 3 + buf.buf[22] = 0x03; + buf.buf[23..26].copy_from_slice(b"www"); + + let record = DnsRecord::read(&mut buf).expect("Failed to read CNAME record"); + + match record { + DnsRecord::CNAME { domain, host, ttl } => { + assert_eq!(domain, "google.com"); + assert_eq!(host, "www"); + assert_eq!(ttl, 293); + } + _ => panic!("Record is not of type CNAME"), + } + } + + #[test] + fn test_write_cname_record() { + let record = DnsRecord::CNAME { + domain: "google.com".to_string(), + host: "www".to_string(), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write CNAME record"); + + assert_eq!(buf.buf[0], 6); + assert_eq!(&buf.buf[1..7], b"google"); + assert_eq!(buf.buf[7], 3); + assert_eq!(&buf.buf[8..11], b"com"); + assert_eq!(buf.buf[11], 0); + assert_eq!(buf.buf[12], 0x00); + assert_eq!(buf.buf[13], 0x05); // qtype CNAME + assert_eq!(buf.buf[14], 0x00); + assert_eq!(buf.buf[15], 0x01); // class IN + assert_eq!(buf.buf[16], 0x00); + assert_eq!(buf.buf[17], 0x00); + assert_eq!(buf.buf[18], 0x01); + assert_eq!(buf.buf[19], 0x25); // ttl = 293 + assert_eq!(buf.buf[20], 0x00); + assert_eq!(buf.buf[21], 0x05); // data_len = 3 + assert_eq!(buf.buf[22], 3); + assert_eq!(&buf.buf[23..26], b"www"); + } + + #[test] + fn test_cname_record() { + let record = DnsRecord::CNAME { + domain: "google.com".to_string(), + host: "www".to_string(), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write CNAME record"); + assert!(buf.seek(0).is_ok()); + let read_record = DnsRecord::read(&mut buf).expect("Failed to read CNAME record"); + + assert_eq!(record, read_record); + } + + #[test] + fn test_read_mx_record() { + let mut buf = BytePacketBuffer::new(); + + buf.buf[0] = 6; + buf.buf[1..7].copy_from_slice(b"google"); + buf.buf[7] = 3; + buf.buf[8..11].copy_from_slice(b"com"); + buf.buf[11] = 0; + buf.buf[12] = 0x00; + buf.buf[13] = 0x0F; // qtype MX + buf.buf[14] = 0x00; + buf.buf[15] = 0x01; // class IN + buf.buf[16] = 0x00; + buf.buf[17] = 0x00; + buf.buf[18] = 0x01; + buf.buf[19] = 0x25; // ttl = 293 + buf.buf[20] = 0x00; + buf.buf[21] = 0x09; // data_len = 9 + buf.buf[22] = 0x00; + buf.buf[23] = 0x0A; // priority = 10 + buf.buf[24] = 0x03; + buf.buf[25..28].copy_from_slice(b"mx1"); + buf.buf[28] = 3; + buf.buf[29..32].copy_from_slice(b"com"); + buf.buf[32] = 0; + + let record = DnsRecord::read(&mut buf).expect("Failed to read MX record"); + + match record { + DnsRecord::MX { domain, priority, host, ttl } => { + assert_eq!(domain, "google.com"); + assert_eq!(priority, 10); + assert_eq!(host, "mx1.com"); + assert_eq!(ttl, 293); + } + _ => panic!("Record is not of type MX"), + } + } + + #[test] + fn test_write_mx_record() { + let record = DnsRecord::MX { + domain: "google.com".to_string(), + priority: 10, + host: "mx1.com".to_string(), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write MX record"); + + assert_eq!(buf.buf[0], 6); + assert_eq!(&buf.buf[1..7], b"google"); + assert_eq!(buf.buf[7], 3); + assert_eq!(&buf.buf[8..11], b"com"); + assert_eq!(buf.buf[11], 0); + assert_eq!(buf.buf[12], 0x00); + assert_eq!(buf.buf[13], 0x0F); // qtype MX + assert_eq!(buf.buf[14], 0x00); + assert_eq!(buf.buf[15], 0x01); // class IN + assert_eq!(buf.buf[16], 0x00); + assert_eq!(buf.buf[17], 0x00); + assert_eq!(buf.buf[18], 0x01); + assert_eq!(buf.buf[19], 0x25); // ttl = 293 + assert_eq!(buf.buf[20], 0x00); + assert_eq!(buf.buf[21], 0x0b); // data_len = 9 + assert_eq!(buf.buf[22], 0x00); + assert_eq!(buf.buf[23], 0x0A); // priority = 10 + assert_eq!(buf.buf[24], 0x03); + assert_eq!(&buf.buf[25..28], b"mx1"); + assert_eq!(buf.buf[28], 3); + assert_eq!(&buf.buf[29..32], b"com"); + assert_eq!(buf.buf[32], 0); + } + + #[test] + fn test_mx_record() { + let record = DnsRecord::MX { + domain: "google.com".to_string(), + priority: 10, + host: "mx1.com".to_string(), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write MX record"); + assert!(buf.seek(0).is_ok()); + let read_record = DnsRecord::read(&mut buf).expect("Failed to read MX record"); + + assert_eq!(record, read_record); + } + + #[test] + fn test_read_ns_record() { + let mut buf = BytePacketBuffer::new(); + + buf.buf[0] = 6; + buf.buf[1..7].copy_from_slice(b"google"); + buf.buf[7] = 3; + buf.buf[8..11].copy_from_slice(b"com"); + buf.buf[11] = 0; + buf.buf[12] = 0x00; + buf.buf[13] = 0x02; // qtype NS + buf.buf[14] = 0x00; + buf.buf[15] = 0x01; // class IN + buf.buf[16] = 0x00; + buf.buf[17] = 0x00; + buf.buf[18] = 0x01; + buf.buf[19] = 0x25; // ttl = 293 + buf.buf[20] = 0x00; + buf.buf[21] = 0x09; // data_len = 9 + buf.buf[22] = 0x03; + buf.buf[23..26].copy_from_slice(b"ns1"); + buf.buf[26] = 3; + buf.buf[27..30].copy_from_slice(b"com"); + buf.buf[30] = 0; + + let record = DnsRecord::read(&mut buf).expect("Failed to read NS record"); + + match record { + DnsRecord::NS { domain, host, ttl } => { + assert_eq!(domain, "google.com"); + assert_eq!(host, "ns1.com"); + assert_eq!(ttl, 293); + } + _ => panic!("Record is not of type NS"), + } + } + + #[test] + fn test_write_ns_record() { + let record = DnsRecord::NS { + domain: "google.com".to_string(), + host: "ns1.com".to_string(), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write NS record"); + + assert_eq!(buf.buf[0], 6); + assert_eq!(&buf.buf[1..7], b"google"); + assert_eq!(buf.buf[7], 3); + assert_eq!(&buf.buf[8..11], b"com"); + assert_eq!(buf.buf[11], 0); + assert_eq!(buf.buf[12], 0x00); + assert_eq!(buf.buf[13], 0x02); // qtype NS + assert_eq!(buf.buf[14], 0x00); + assert_eq!(buf.buf[15], 0x01); // class IN + assert_eq!(buf.buf[16], 0x00); + assert_eq!(buf.buf[17], 0x00); + assert_eq!(buf.buf[18], 0x01); + assert_eq!(buf.buf[19], 0x25); // ttl = 293 + assert_eq!(buf.buf[20], 0x00); + assert_eq!(buf.buf[21], 0x09); // data_len = 9 + assert_eq!(buf.buf[22], 0x03); + assert_eq!(&buf.buf[23..26], b"ns1"); + assert_eq!(buf.buf[26], 3); + assert_eq!(&buf.buf[27..30], b"com"); + assert_eq!(buf.buf[30], 0); + } + + #[test] + fn test_ns_record() { + let record = DnsRecord::NS { + domain: "google.com".to_string(), + host: "ns1.com".to_string(), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write NS record"); + assert!(buf.seek(0).is_ok()); + let read_record = DnsRecord::read(&mut buf).expect("Failed to read NS record"); + + assert_eq!(record, read_record); + } + +} \ No newline at end of file diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..44b14f1 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,127 @@ +use std::{io::{self}, net::{Ipv4Addr, UdpSocket}}; +mod byte_packet_buffer; +mod dns_header; +mod dns_question; +mod dns_record; +mod query_types; +mod dns_packet; +use {byte_packet_buffer::BytePacketBuffer, dns_header::ResultCode, dns_question::DnsQuestion, query_types::QueryType, dns_packet::DnsPacket}; + + +fn recursive_lookup(qname: &str, qtype: QueryType) -> Result { + // For now we're always starting with *a.root-servers.net*. + let mut ns = "198.41.0.4".parse::().unwrap(); + + loop { + println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns); + + let ns_copy = ns; + + let server = (ns_copy, 53); + let response = lookup(qname, qtype, server)?; + + if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { + return Ok(response); + } + + if response.header.rescode == ResultCode::NXDOMAIN { + return Ok(response); + } + + if let Some(new_ns) = response.get_resolved_ns(qname) { + ns = new_ns; + + continue; + } + + let new_ns_name = match response.get_unresolved_ns(qname) { + Some(x) => x, + None => return Ok(response), + }; + + let recursive_response = recursive_lookup(&new_ns_name, QueryType::A)?; + + if let Some(new_ns) = recursive_response.get_random_a() { + ns = new_ns; + } else { + return Ok(response); + } + } +} + +fn lookup(qname: &str, qtype: QueryType, server: (Ipv4Addr, u16)) -> Result { + let socket = UdpSocket::bind(("0.0.0.0", 43210))?; + + let mut packet = DnsPacket::new(); + + packet.header.id = 6666; + packet.header.questions = 1; + packet.header.recursion_desired = true; + packet + .questions + .push(DnsQuestion::new(qname.to_string(), qtype)); + + let mut req_buffer = BytePacketBuffer::new(); + packet.write(&mut req_buffer)?; + socket.send_to(&req_buffer.get_range(0, req_buffer.pos()+1)?, server)?; + + let mut res_buffer = BytePacketBuffer::new(); + socket.recv_from(&mut res_buffer.buf)?; + + DnsPacket::from_buffer(&mut res_buffer) +} + +/// Handle query +/// function handles an incoming DNS query from a UDP socket +/// and formulates an appropriate DNS response, +/// either answering directly or using recursive resolution to obtain the necessary data +pub fn handle_query(socket: &UdpSocket) -> Result<(), io::Error> { + let mut req_buffer = BytePacketBuffer::new(); + + + let (_, src) = socket.recv_from(&mut req_buffer.buf)?; + + + let mut request = DnsPacket::from_buffer(&mut req_buffer)?; + + let mut packet = DnsPacket::new(); + packet.header.id = request.header.id; + packet.header.recursion_desired = true; + packet.header.recursion_available = true; + packet.header.response = true; + + if let Some(question) = request.questions.pop() { + println!("Received query: {:?}", question); + if let Ok(result) = recursive_lookup(&question.name, question.qtype) { + packet.questions.push(question.clone()); + packet.header.rescode = result.header.rescode; + + for rec in result.answers { + println!("Answer: {:?}", rec); + packet.answers.push(rec); + } + for rec in result.authorities { + println!("Authority: {:?}", rec); + packet.authorities.push(rec); + } + for rec in result.resources { + println!("Resource: {:?}", rec); + packet.resources.push(rec); + } + } else { + packet.header.rescode = ResultCode::SERVFAIL; + } + } else { + packet.header.rescode = ResultCode::FORMERR; + } + + let mut res_buffer = BytePacketBuffer::new(); + packet.write(&mut res_buffer)?; + + let len = res_buffer.pos(); + let data = res_buffer.get_range(0, len)?; + + socket.send_to(data, src)?; + + Ok(()) +} \ No newline at end of file diff --git a/src/main.rs b/src/main.rs index 753d4b6..5f4bf93 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,833 +1,6 @@ -use std::{fs::File, io::{self, Error, Read}, net::{Ipv4Addr, Ipv6Addr, UdpSocket}}; +use std::{io, net::UdpSocket}; -pub struct BytePacketBuffer { - pub buf: [u8; 512], - pub pos: usize, -} - -impl BytePacketBuffer { - pub fn new () -> BytePacketBuffer { - return BytePacketBuffer { - buf: [0; 512], - pos: 0 - }; - } - fn pos(&self) -> usize{ - return self.pos; - } - fn step(&mut self, steps: usize) -> Result<(), io::Error> { - self.pos += steps; - Ok(()) - } - fn seek(&mut self, pos: usize) -> Result<(), io::Error> { - self.pos = pos; - - Ok(()) - } - fn read(&mut self) -> Result { - if self.pos >= 512 { - return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); - } - let res = self.buf[self.pos]; - self.pos += 1; - Ok(res) - } - fn get(& self, pos: usize) -> Result { - if pos >= 512 { - return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); - } - Ok(self.buf[pos]) - } - fn get_range(& self, start: usize, len: usize) -> Result<&[u8], io::Error> { - if start+len >= 512 { - return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); - } - Ok(&self.buf[start..start+len as usize]) - } - fn read_u16(&mut self) -> Result { - let res = ( ( self.read()? as u16) << 8) | ( self.read()? as u16); - Ok(res) - } - fn read_u32(&mut self) -> Result { - let res = ((self.read()? as u32) << 24) - | ((self.read()? as u32) << 16) - | ((self.read()? as u32) << 8) - | ((self.read()? as u32) << 0); - - Ok(res) - } - fn read_qname(&mut self, outstr: &mut String) -> Result<(), io::Error> { - - let mut pos = self.pos(); - let mut jumped = false; - let max_jumps = 5; - let mut jumps_performed = 0; - let mut delim = ""; - loop { - if jumps_performed > max_jumps { - return Err(Error::new(io::ErrorKind::UnexpectedEof, format!("Limit of {} jumps exceeded", max_jumps))); - } - - let len = self.get(pos)?; - - if (len & 0xC0) == 0xC0 { - if !jumped { - self.seek(pos+2)? - } - - let b2 = self.get(pos + 1)? as u16; - let offset = (((len as u16) ^ 0xC0) << 8) | b2; - pos = offset as usize; - - jumped = true; - jumps_performed += 1; - }else { - pos += 1; - - if len == 0 { - break; - } - outstr.push_str(delim); - let str_buffer = self.get_range(pos, len as usize)?; - outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase()); - - delim = "."; - pos += len as usize; - } - } - if !jumped { - self.seek(pos)?; - } - - Ok(()) - } - fn write(&mut self, val: u8) -> Result<(), io::Error> { - if self.pos >= 512 { - return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); - } - self.buf[self.pos] = val; - self.pos += 1; - Ok(()) - } - - fn write_u8(&mut self, val: u8) -> Result<(), io::Error> { - self.write(val)?; - - Ok(()) - } - - fn write_u16(&mut self, val: u16) -> Result<(), io::Error> { - self.write((val >> 8) as u8)?; - self.write((val & 0xFF) as u8)?; - - Ok(()) - } - - fn write_u32(&mut self, val: u32) -> Result<(), io::Error> { - self.write(((val >> 24) & 0xFF) as u8)?; - self.write(((val >> 16) & 0xFF) as u8)?; - self.write(((val >> 8) & 0xFF) as u8)?; - self.write(((val >> 0) & 0xFF) as u8)?; - - Ok(()) - } - fn write_qname(&mut self, qname: &str) -> Result<(), io::Error> { - for label in qname.split('.') { - let len = label.len(); - if len > 0x3f { - return Err(Error::new(io::ErrorKind::UnexpectedEof, "Single label exceeds 63 characters of length")); - } - - self.write_u8(len as u8)?; - for b in label.as_bytes() { - self.write_u8(*b)?; - } - } - - self.write_u8(0)?; - - Ok(()) - } - - fn set(&mut self, pos: usize, val: u8) -> Result<(), io::Error> { - self.buf[pos] = val; - - Ok(()) - } - - fn set_u16(&mut self, pos: usize, val: u16) -> Result<(), io::Error> { - self.set(pos, (val >> 8) as u8)?; - self.set(pos + 1, (val & 0xFF) as u8)?; - - Ok(()) - } -} - -#[derive(Copy, Clone, Debug, PartialEq, Eq)] -pub enum ResultCode { - NOERROR = 0, - FORMERR = 1, - SERVFAIL = 2, - NXDOMAIN = 3, - NOTIMP = 4, - REFUSED = 5, -} - -impl ResultCode { - pub fn from_num(num: u8) -> ResultCode { - match num { - 1 => ResultCode::FORMERR, - 2 => ResultCode::SERVFAIL, - 3 => ResultCode::NXDOMAIN, - 4 => ResultCode::NOTIMP, - 5 => ResultCode::REFUSED, - 0 | _ => ResultCode::NOERROR, - } - } -} - - -#[derive(Clone, Debug)] -pub struct DnsHeader { - pub id: u16, // 16 bits - - pub recursion_desired: bool, // 1 bit - pub truncated_message: bool, // 1 bit - pub authoritative_answer: bool, // 1 bit - pub opcode: u8, // 4 bits - pub response: bool, // 1 bit - - pub rescode: ResultCode, // 4 bits - pub checking_disabled: bool, // 1 bit - pub authed_data: bool, // 1 bit - pub z: bool, // 1 bit - pub recursion_available: bool, // 1 bit - - pub questions: u16, // 16 bits - pub answers: u16, // 16 bits - pub authoritative_entries: u16, // 16 bits - pub resource_entries: u16, // 16 bits -} - -impl DnsHeader { - pub fn new() -> DnsHeader { - DnsHeader { - id: 0, - - recursion_desired: false, - truncated_message: false, - authoritative_answer: false, - opcode: 0, - response: false, - - rescode: ResultCode::NOERROR, - checking_disabled: false, - authed_data: false, - z: false, - recursion_available: false, - - questions: 0, - answers: 0, - authoritative_entries: 0, - resource_entries: 0, - } - } - - pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { - self.id = buffer.read_u16()?; - - let mut flags = buffer.read()?; - self.recursion_desired = (flags & (1 << 0)) > 0; - self.truncated_message = (flags & (1 << 1)) > 0; - self.authoritative_answer = (flags & (1 << 2)) > 0; - self.opcode = (flags >> 3) & 0x0F; - self.response = (flags & (1 << 7)) > 0; - - flags = buffer.read()?; - self.rescode = ResultCode::from_num(flags & 0x0F); - self.checking_disabled = (flags & (1 << 4)) > 0; - self.authed_data = (flags & (1 << 5)) > 0; - self.z = (flags & (1 << 6)) > 0; - self.recursion_available = (flags & (1 << 7)) > 0; - - self.questions = buffer.read_u16()?; - self.answers = buffer.read_u16()?; - self.authoritative_entries = buffer.read_u16()?; - self.resource_entries = buffer.read_u16()?; - - Ok(()) - } - pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { - buffer.write_u16(self.id)?; - - buffer.write_u8( - (self.recursion_desired as u8) - | ((self.truncated_message as u8) << 1) - | ((self.authoritative_answer as u8) << 2) - | (self.opcode << 3) - | ((self.response as u8) << 7) as u8, - )?; - - buffer.write_u8( - (self.rescode as u8) - | ((self.checking_disabled as u8) << 4) - | ((self.authed_data as u8) << 5) - | ((self.z as u8) << 6) - | ((self.recursion_available as u8) << 7), - )?; - - buffer.write_u16(self.questions)?; - buffer.write_u16(self.answers)?; - buffer.write_u16(self.authoritative_entries)?; - buffer.write_u16(self.resource_entries)?; - - Ok(()) - } -} - -#[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)] -pub enum QueryType { - UNKNOWN(u16), - A, // 1 - NS, // 2 - CNAME, // 5 - MX, // 15 - AAAA, // 28 -} - -impl QueryType { - pub fn to_num(&self) -> u16 { - match *self { - QueryType::UNKNOWN(x) => x, - QueryType::A => 1, - QueryType::NS => 2, - QueryType::CNAME => 5, - QueryType::MX => 15, - QueryType::AAAA => 28, - } - } - - pub fn from_num(num: u16) -> QueryType { - match num { - 1 => QueryType::A, - 2 => QueryType::NS, - 5 => QueryType::CNAME, - 15 => QueryType::MX, - 28 => QueryType::AAAA, - _ => QueryType::UNKNOWN(num), - } - } -} - -#[derive(Debug, Clone, PartialEq, Eq)] -pub struct DnsQuestion { - pub name: String, - pub qtype: QueryType, -} - -impl DnsQuestion { - pub fn new(name: String, qtype: QueryType) -> DnsQuestion { - DnsQuestion { - name: name, - qtype: qtype, - } - } - - pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { - buffer.read_qname(&mut self.name)?; - self.qtype = QueryType::from_num(buffer.read_u16()?); // qtype - let _ = buffer.read_u16()?; // class - - Ok(()) - } - - pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { - buffer.write_qname(&self.name)?; - - let typenum = self.qtype.to_num(); - buffer.write_u16(typenum)?; - buffer.write_u16(1)?; - - Ok(()) - } -} - -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] -#[allow(dead_code)] -pub enum DnsRecord { - UNKNOWN { - domain: String, - qtype: u16, - data_len: u16, - ttl: u32, - }, // 0 - A { - domain: String, - addr: Ipv4Addr, - ttl: u32, - }, // 1 - NS { - domain: String, - host: String, - ttl: u32, - }, // 2 - CNAME { - domain: String, - host: String, - ttl: u32, - }, // 5 - MX { - domain: String, - priority: u16, - host: String, - ttl: u32, - }, // 15 - AAAA { - domain: String, - addr: Ipv6Addr, - ttl: u32, - }, // 28 -} -impl DnsRecord { - pub fn read(buffer: &mut BytePacketBuffer) -> Result { - let mut domain = String::new(); - buffer.read_qname(&mut domain)?; - - let qtype_num = buffer.read_u16()?; - let qtype = QueryType::from_num(qtype_num); - let _ = buffer.read_u16()?; - let ttl = buffer.read_u32()?; - let data_len = buffer.read_u16()?; - - match qtype { - QueryType::A => { - let raw_addr = buffer.read_u32()?; - let addr = Ipv4Addr::new( - ((raw_addr >> 24) & 0xFF) as u8, - ((raw_addr >> 16) & 0xFF) as u8, - ((raw_addr >> 8) & 0xFF) as u8, - ((raw_addr >> 0) & 0xFF) as u8, - ); - - Ok(DnsRecord::A { - domain: domain, - addr: addr, - ttl: ttl, - }) - } - QueryType::AAAA => { - let raw_addr1 = buffer.read_u32()?; - let raw_addr2 = buffer.read_u32()?; - let raw_addr3 = buffer.read_u32()?; - let raw_addr4 = buffer.read_u32()?; - let addr = Ipv6Addr::new( - ((raw_addr1 >> 16) & 0xFFFF) as u16, - ((raw_addr1 >> 0) & 0xFFFF) as u16, - ((raw_addr2 >> 16) & 0xFFFF) as u16, - ((raw_addr2 >> 0) & 0xFFFF) as u16, - ((raw_addr3 >> 16) & 0xFFFF) as u16, - ((raw_addr3 >> 0) & 0xFFFF) as u16, - ((raw_addr4 >> 16) & 0xFFFF) as u16, - ((raw_addr4 >> 0) & 0xFFFF) as u16, - ); - - Ok(DnsRecord::AAAA { - domain: domain, - addr: addr, - ttl: ttl, - }) - } - QueryType::NS => { - let mut ns = String::new(); - buffer.read_qname(&mut ns)?; - - Ok(DnsRecord::NS { - domain: domain, - host: ns, - ttl: ttl, - }) - } - QueryType::CNAME => { - let mut cname = String::new(); - buffer.read_qname(&mut cname)?; - - Ok(DnsRecord::CNAME { - domain: domain, - host: cname, - ttl: ttl, - }) - } - QueryType::MX => { - let priority = buffer.read_u16()?; - let mut mx = String::new(); - buffer.read_qname(&mut mx)?; - - Ok(DnsRecord::MX { - domain: domain, - priority: priority, - host: mx, - ttl: ttl, - }) - } - QueryType::UNKNOWN(_) => { - buffer.step(data_len as usize)?; - - Ok(DnsRecord::UNKNOWN { - domain: domain, - qtype: qtype_num, - data_len: data_len, - ttl: ttl, - }) - } - } - } - - pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result { - let start_pos = buffer.pos(); - - match *self { - DnsRecord::A { - ref domain, - ref addr, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::A.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - buffer.write_u16(4)?; - - let octets = addr.octets(); - buffer.write_u8(octets[0])?; - buffer.write_u8(octets[1])?; - buffer.write_u8(octets[2])?; - buffer.write_u8(octets[3])?; - } - DnsRecord::NS { - ref domain, - ref host, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::NS.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - - let pos = buffer.pos(); - buffer.write_u16(0)?; - - buffer.write_qname(host)?; - - let size = buffer.pos() - (pos + 2); - buffer.set_u16(pos, size as u16)?; - } - DnsRecord::CNAME { - ref domain, - ref host, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::CNAME.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - - let pos = buffer.pos(); - buffer.write_u16(0)?; - - buffer.write_qname(host)?; - - let size = buffer.pos() - (pos + 2); - buffer.set_u16(pos, size as u16)?; - } - DnsRecord::MX { - ref domain, - priority, - ref host, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::MX.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - - let pos = buffer.pos(); - buffer.write_u16(0)?; - - buffer.write_u16(priority)?; - buffer.write_qname(host)?; - - let size = buffer.pos() - (pos + 2); - buffer.set_u16(pos, size as u16)?; - } - DnsRecord::AAAA { - ref domain, - ref addr, - ttl, - } => { - buffer.write_qname(domain)?; - buffer.write_u16(QueryType::AAAA.to_num())?; - buffer.write_u16(1)?; - buffer.write_u32(ttl)?; - buffer.write_u16(16)?; - - for octet in &addr.segments() { - buffer.write_u16(*octet)?; - } - } - DnsRecord::UNKNOWN { .. } => { - println!("Skipping record: {:?}", self); - } - } - - Ok(buffer.pos() - start_pos) - } -} - - -#[derive(Clone, Debug)] -pub struct DnsPacket { - pub header: DnsHeader, - pub questions: Vec, - pub answers: Vec, - pub authorities: Vec, - pub resources: Vec, -} - -impl DnsPacket { - pub fn new() -> DnsPacket { - DnsPacket { - header: DnsHeader::new(), - questions: Vec::new(), - answers: Vec::new(), - authorities: Vec::new(), - resources: Vec::new(), - } - } - - pub fn from_buffer(buffer: &mut BytePacketBuffer) -> Result { - let mut result = DnsPacket::new(); - result.header.read(buffer)?; - - for _ in 0..result.header.questions { - let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0)); - question.read(buffer)?; - result.questions.push(question); - } - - for _ in 0..result.header.answers { - let rec = DnsRecord::read(buffer)?; - result.answers.push(rec); - } - for _ in 0..result.header.authoritative_entries { - let rec = DnsRecord::read(buffer)?; - result.authorities.push(rec); - } - for _ in 0..result.header.resource_entries { - let rec = DnsRecord::read(buffer)?; - result.resources.push(rec); - } - - Ok(result) - } - - pub fn write(&mut self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { - self.header.questions = self.questions.len() as u16; - self.header.answers = self.answers.len() as u16; - self.header.authoritative_entries = self.authorities.len() as u16; - self.header.resource_entries = self.resources.len() as u16; - - self.header.write(buffer)?; - - for question in &self.questions { - question.write(buffer)?; - } - for rec in &self.answers { - rec.write(buffer)?; - } - for rec in &self.authorities { - rec.write(buffer)?; - } - for rec in &self.resources { - rec.write(buffer)?; - } - - Ok(()) - } - - pub fn get_random_a(&self) -> Option { - self.answers - .iter() - .filter_map(|record| match record { - DnsRecord::A { addr, .. } => Some(*addr), - _ => None, - }) - .next() - } - - fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator { - self.authorities - .iter() - .filter_map(|record| match record { - DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())), - _ => None, - }) - .filter(move |(domain, _)| qname.ends_with(*domain)) - } - - - pub fn get_resolved_ns(&self, qname: &str) -> Option { - self.get_ns(qname) - .flat_map(|(_, host)| { - self.resources - .iter() - .filter_map(move |record| match record { - DnsRecord::A { domain, addr, .. } if domain == host => Some(addr), - _ => None, - }) - }) - .map(|addr| *addr) - .next() - } - - pub fn get_unresolved_ns<'a>(&'a self, qname: &'a str) -> Option<&'a str> { - self.get_ns(qname) - .map(|(_, host)| host) - .next() - } -} - - -fn recursive_lookup(qname: &str, qtype: QueryType) -> Result { - // For now we're always starting with *a.root-servers.net*. - let mut ns = "198.41.0.4".parse::().unwrap(); - - loop { - println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns); - - // send the query to the active server. - let ns_copy = ns; - - let server = (ns_copy, 53); - let response = lookup(qname, qtype, server)?; - - // If there are entries in the answer section, and no errors, we are done! - if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { - return Ok(response); - } - - // We might also get a `NXDOMAIN` reply, which is the authoritative name servers - // way of telling us that the name doesn't exist. - if response.header.rescode == ResultCode::NXDOMAIN { - return Ok(response); - } - - // Otherwise, we'll try to find a new nameserver based on NS and a corresponding A - // record in the additional section. If this succeeds, we can switch name server - // and retry the loop. - if let Some(new_ns) = response.get_resolved_ns(qname) { - ns = new_ns; - - continue; - } - - // If not, we'll have to resolve the ip of a NS record. If no NS records exist, - // we'll go with what the last server told us. - let new_ns_name = match response.get_unresolved_ns(qname) { - Some(x) => x, - None => return Ok(response), - }; - - // Here we go down the rabbit hole by starting _another_ lookup sequence in the - // midst of our current one. Hopefully, this will give us the IP of an appropriate - // name server. - let recursive_response = recursive_lookup(&new_ns_name, QueryType::A)?; - - // Finally, we pick a random ip from the result, and restart the loop. If no such - // record is available, we again return the last result we got. - if let Some(new_ns) = recursive_response.get_random_a() { - ns = new_ns; - } else { - return Ok(response); - } - } -} - -fn lookup(qname: &str, qtype: QueryType, server: (Ipv4Addr, u16)) -> Result { - // Forward queries to Google's public DNS - //let server = ("8.8.8.8", 53); - - let socket = UdpSocket::bind(("0.0.0.0", 43210))?; - - let mut packet = DnsPacket::new(); - - packet.header.id = 6666; - packet.header.questions = 1; - packet.header.recursion_desired = true; - packet - .questions - .push(DnsQuestion::new(qname.to_string(), qtype)); - - let mut req_buffer = BytePacketBuffer::new(); - packet.write(&mut req_buffer)?; - socket.send_to(&req_buffer.buf[0..req_buffer.pos], server)?; - - let mut res_buffer = BytePacketBuffer::new(); - socket.recv_from(&mut res_buffer.buf)?; - - DnsPacket::from_buffer(&mut res_buffer) -} - -fn handle_query(socket: &UdpSocket) -> Result<(), io::Error> { - let mut req_buffer = BytePacketBuffer::new(); - - - let (_, src) = socket.recv_from(&mut req_buffer.buf)?; - - - let mut request = DnsPacket::from_buffer(&mut req_buffer)?; - - let mut packet = DnsPacket::new(); - packet.header.id = request.header.id; - packet.header.recursion_desired = true; - packet.header.recursion_available = true; - packet.header.response = true; - - if let Some(question) = request.questions.pop() { - println!("Received query: {:?}", question); - if let Ok(result) = recursive_lookup(&question.name, question.qtype) { - packet.questions.push(question.clone()); - packet.header.rescode = result.header.rescode; - - for rec in result.answers { - println!("Answer: {:?}", rec); - packet.answers.push(rec); - } - for rec in result.authorities { - println!("Authority: {:?}", rec); - packet.authorities.push(rec); - } - for rec in result.resources { - println!("Resource: {:?}", rec); - packet.resources.push(rec); - } - } else { - packet.header.rescode = ResultCode::SERVFAIL; - } - } else { - packet.header.rescode = ResultCode::FORMERR; - } - - let mut res_buffer = BytePacketBuffer::new(); - packet.write(&mut res_buffer)?; - - let len = res_buffer.pos(); - let data = res_buffer.get_range(0, len)?; - - socket.send_to(data, src)?; - - Ok(()) -} fn main() -> Result<(), io::Error> { // Bind an UDP socket on port 2053 @@ -836,7 +9,7 @@ fn main() -> Result<(), io::Error> { // For now, queries are handled sequentially, so an infinite loop for servicing // requests is initiated. loop { - match handle_query(&socket) { + match dnsserver_nabil::handle_query(&socket) { Ok(_) => {}, Err(e) => eprintln!("An error occurred: {}", e), } diff --git a/src/query_types.rs b/src/query_types.rs new file mode 100644 index 0000000..ac993bd --- /dev/null +++ b/src/query_types.rs @@ -0,0 +1,33 @@ +#[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)] +pub enum QueryType { + UNKNOWN(u16), + A, // 1 + NS, // 2 + CNAME, // 5 + MX, // 15 + AAAA, // 28 +} + +impl QueryType { + pub fn to_num(&self) -> u16 { + match *self { + QueryType::UNKNOWN(x) => x, + QueryType::A => 1, + QueryType::NS => 2, + QueryType::CNAME => 5, + QueryType::MX => 15, + QueryType::AAAA => 28, + } + } + + pub fn from_num(num: u16) -> QueryType { + match num { + 1 => QueryType::A, + 2 => QueryType::NS, + 5 => QueryType::CNAME, + 15 => QueryType::MX, + 28 => QueryType::AAAA, + _ => QueryType::UNKNOWN(num), + } + } +} \ No newline at end of file From 84bca38ef9e2a2db8702ebb9f2e6992b6c7c092e Mon Sep 17 00:00:00 2001 From: nabil salah Date: Sat, 21 Dec 2024 22:14:12 +0200 Subject: [PATCH 03/12] feat: build ci/cd pipeline Signed-off-by: nabil salah --- .github/workflows/rust.yaml | 38 +++++++++++++++++ .gitignore | 3 +- src/byte_packet_buffer.rs | 57 +++++++++++++------------- src/dns_header.rs | 17 ++++---- src/dns_packet.rs | 23 ++++++----- src/dns_question.rs | 8 ++-- src/dns_record.rs | 81 +++++++++++++++++++++---------------- src/lib.rs | 31 +++++++------- src/main.rs | 4 +- src/query_types.rs | 2 +- 10 files changed, 163 insertions(+), 101 deletions(-) create mode 100644 .github/workflows/rust.yaml diff --git a/.github/workflows/rust.yaml b/.github/workflows/rust.yaml new file mode 100644 index 0000000..2fe10ab --- /dev/null +++ b/.github/workflows/rust.yaml @@ -0,0 +1,38 @@ +name: Rust + +on: [push] + +jobs: + build: + name: Test-Clippy-Build + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + name: Checkout code + with: + fetch-depth: 1 + - uses: actions-rs/toolchain@v1 + name: Install toolchain + with: + toolchain: stable + - uses: actions-rs/cargo@v1 + name: Check formatting + with: + command: fmt + args: -- --check + - uses: actions-rs/cargo@v1 + name: Run clippy + with: + command: clippy + + - uses: actions-rs/cargo@v1 + name: Run tests + with: + command: test + + - uses: actions-rs/cargo@v1 + name: Build + with: + toolchain: stable + command: build + args: --release diff --git a/.gitignore b/.gitignore index 41a1821..1acac83 100644 --- a/.gitignore +++ b/.gitignore @@ -1,4 +1,5 @@ /target query_packet.txt response_packet.txt -.vscode \ No newline at end of file +.vscode +Authoritative-DNS-server.png \ No newline at end of file diff --git a/src/byte_packet_buffer.rs b/src/byte_packet_buffer.rs index f3bcaeb..a918a1c 100644 --- a/src/byte_packet_buffer.rs +++ b/src/byte_packet_buffer.rs @@ -8,16 +8,16 @@ pub struct BytePacketBuffer { impl BytePacketBuffer { /// This gives us a fresh buffer for holding the packet contents, and a /// field for keeping track of where we are. - pub fn new () -> BytePacketBuffer { - return BytePacketBuffer { + pub fn new() -> BytePacketBuffer { + BytePacketBuffer { buf: [0; 512], - pos: 0 - }; + pos: 0, + } } /// Current position within buffer - pub fn pos(&self) -> usize{ - return self.pos; + pub fn pos(&self) -> usize { + self.pos } /// Step the buffer position forward a specific number of steps @@ -49,7 +49,7 @@ impl BytePacketBuffer { } /// Get a single byte, without changing the buffer position - pub fn get(& self, pos: usize) -> Result { + pub fn get(&self, pos: usize) -> Result { if pos >= 512 { return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); } @@ -57,16 +57,16 @@ impl BytePacketBuffer { } /// Get a range of bytes - pub fn get_range(& self, start: usize, len: usize) -> Result<&[u8], io::Error> { - if start+len > 512 { + pub fn get_range(&self, start: usize, len: usize) -> Result<&[u8], io::Error> { + if start + len > 512 { return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); } - Ok(&self.buf[start..(start+len) as usize]) + Ok(&self.buf[start..(start + len)]) } /// Read two bytes, stepping two steps forward pub fn read_u16(&mut self) -> Result { - let res = ( ( self.read()? as u16) << 8) | ( self.read()? as u16); + let res = ((self.read()? as u16) << 8) | (self.read()? as u16); Ok(res) } @@ -81,13 +81,12 @@ impl BytePacketBuffer { } /// Read a qname - /// + /// /// Will take something like [3]www[6]google[3]com[0] and append /// www.google.com to outstr. - /// + /// /// also it handle jemps . pub fn read_qname(&mut self, outstr: &mut String) -> Result<(), io::Error> { - let mut pos = self.pos(); let mut jumped = false; let max_jumps = 5; @@ -95,14 +94,17 @@ impl BytePacketBuffer { let mut delim = ""; loop { if jumps_performed > max_jumps { - return Err(Error::new(io::ErrorKind::UnexpectedEof, format!("Limit of {} jumps exceeded", max_jumps))); + return Err(Error::new( + io::ErrorKind::UnexpectedEof, + format!("Limit of {} jumps exceeded", max_jumps), + )); } let len = self.get(pos)?; - + if (len & 0xC0) == 0xC0 { if !jumped { - self.seek(pos+2)? + self.seek(pos + 2)? } let b2 = self.get(pos + 1)? as u16; @@ -111,7 +113,7 @@ impl BytePacketBuffer { jumped = true; jumps_performed += 1; - }else { + } else { pos += 1; if len == 0 { @@ -168,15 +170,18 @@ impl BytePacketBuffer { } /// Write a qname - /// + /// /// Will take something like www.google.com - /// + /// /// dots are the separator used pub fn write_qname(&mut self, qname: &str) -> Result<(), io::Error> { for label in qname.split('.') { let len = label.len(); if len > 0x3f { - return Err(Error::new(io::ErrorKind::UnexpectedEof, "Single label exceeds 63 characters of length")); + return Err(Error::new( + io::ErrorKind::UnexpectedEof, + "Single label exceeds 63 characters of length", + )); } self.write_u8(len as u8)?; @@ -231,12 +236,12 @@ mod tests { #[test] fn test_read_write_limits() { let mut buf = BytePacketBuffer::new(); - + // Test reading at limit assert!(buf.seek(511).is_ok()); assert!(buf.read().is_ok()); assert!(buf.read().is_err()); - + // Test writing at limit assert!(buf.seek(511).is_ok()); assert!(buf.write(0xFF).is_ok()); @@ -251,7 +256,7 @@ mod tests { buf.buf[1] = 0x34; buf.buf[2] = 0x56; buf.buf[3] = 0x78; - + assert_eq!(buf.read_u16().unwrap(), 0x1234); assert!(buf.seek(0).is_ok()); assert_eq!(buf.read_u32().unwrap(), 0x12345678); @@ -352,7 +357,6 @@ mod tests { assert_eq!(outstr, "www.google.com"); } - #[test] fn test_write_qname_success() { let mut buf = BytePacketBuffer::new(); @@ -370,5 +374,4 @@ mod tests { assert!(buf.seek(510).is_ok()); assert!(buf.write_qname("a").is_err()); } - -} \ No newline at end of file +} diff --git a/src/dns_header.rs b/src/dns_header.rs index 18f23d4..d4468a9 100644 --- a/src/dns_header.rs +++ b/src/dns_header.rs @@ -25,8 +25,6 @@ impl ResultCode { } } - - #[derive(Clone, Debug)] pub struct DnsHeader { pub id: u16, // 16 bits @@ -133,7 +131,7 @@ mod tests { #[test] fn test_read_header_from_buffer() { let mut buf = BytePacketBuffer::new(); - + buf.buf[0] = 0x12; buf.buf[1] = 0x34; @@ -234,11 +232,13 @@ mod tests { assert!(read_header.read(&mut buf).is_ok()); println!("{read_header:?}"); - assert_eq!(header.id, read_header.id); assert_eq!(header.recursion_desired, read_header.recursion_desired); assert_eq!(header.truncated_message, read_header.truncated_message); - assert_eq!(header.authoritative_answer, read_header.authoritative_answer); + assert_eq!( + header.authoritative_answer, + read_header.authoritative_answer + ); assert_eq!(header.opcode, read_header.opcode); assert_eq!(header.response, read_header.response); assert_eq!(header.rescode, read_header.rescode); @@ -248,7 +248,10 @@ mod tests { assert_eq!(header.recursion_available, read_header.recursion_available); assert_eq!(header.questions, read_header.questions); assert_eq!(header.answers, read_header.answers); - assert_eq!(header.authoritative_entries, read_header.authoritative_entries); + assert_eq!( + header.authoritative_entries, + read_header.authoritative_entries + ); assert_eq!(header.resource_entries, read_header.resource_entries); } -} \ No newline at end of file +} diff --git a/src/dns_packet.rs b/src/dns_packet.rs index 679b2c9..d3fb2b6 100644 --- a/src/dns_packet.rs +++ b/src/dns_packet.rs @@ -1,6 +1,9 @@ use std::{io, net::Ipv4Addr}; -use crate::{byte_packet_buffer::BytePacketBuffer, dns_header::DnsHeader, dns_question::DnsQuestion, dns_record::DnsRecord, query_types::QueryType}; +use crate::{ + byte_packet_buffer::BytePacketBuffer, dns_header::DnsHeader, dns_question::DnsQuestion, + dns_record::DnsRecord, query_types::QueryType, +}; #[derive(Clone, Debug)] pub struct DnsPacket { @@ -99,7 +102,7 @@ impl DnsPacket { .filter(move |(domain, _)| qname.ends_with(*domain)) } - /// Get Resolved NS + /// Get Resolved NS /// as We'll use the fact that name servers often bundle the corresponding /// A records when replying to an NS query to implement a function that /// returns the actual IP for an NS record if possible. @@ -120,13 +123,10 @@ impl DnsPacket { /// Get Unresolved NS a method for returning the host /// name of an appropriate name server. pub fn get_unresolved_ns<'a>(&'a self, qname: &'a str) -> Option<&'a str> { - self.get_ns(qname) - .map(|(_, host)| host) - .next() + self.get_ns(qname).map(|(_, host)| host).next() } } - #[cfg(test)] mod tests { use super::*; @@ -164,10 +164,12 @@ mod tests { }; packet.resources.push(resource); - packet.write(&mut buffer).expect("Failed to write DNS packet"); + packet + .write(&mut buffer) + .expect("Failed to write DNS packet"); assert!(buffer.seek(0).is_ok()); - + let read_packet = DnsPacket::from_buffer(&mut buffer).expect("Failed to read DNS packet"); // Validate the read packet @@ -232,7 +234,10 @@ mod tests { let random_a = packet.get_random_a(); assert!(random_a.is_some()); let random_a = random_a.unwrap(); - assert!(random_a == Ipv4Addr::new(216, 58, 211, 142) || random_a == Ipv4Addr::new(93, 184, 216, 34)); + assert!( + random_a == Ipv4Addr::new(216, 58, 211, 142) + || random_a == Ipv4Addr::new(93, 184, 216, 34) + ); } #[test] diff --git a/src/dns_question.rs b/src/dns_question.rs index 00b719a..13fd4d3 100644 --- a/src/dns_question.rs +++ b/src/dns_question.rs @@ -11,8 +11,8 @@ pub struct DnsQuestion { impl DnsQuestion { pub fn new(name: String, qtype: QueryType) -> DnsQuestion { DnsQuestion { - name: name, - qtype: qtype, + name, + qtype, } } /// Read Dns question. @@ -42,7 +42,7 @@ mod tests { #[test] fn test_read_question_from_buffer() { let mut buf = BytePacketBuffer::new(); - + buf.buf[0] = 6; buf.buf[1..7].copy_from_slice(b"google"); buf.buf[7] = 3; @@ -88,4 +88,4 @@ mod tests { assert_eq!(question.name, read_question.name); assert_eq!(question.qtype, read_question.qtype); } -} \ No newline at end of file +} diff --git a/src/dns_record.rs b/src/dns_record.rs index dfec68c..789763c 100644 --- a/src/dns_record.rs +++ b/src/dns_record.rs @@ -1,4 +1,7 @@ -use std::{io, net::{Ipv4Addr, Ipv6Addr}}; +use std::{ + io, + net::{Ipv4Addr, Ipv6Addr}, +}; use crate::{byte_packet_buffer::BytePacketBuffer, QueryType}; @@ -61,9 +64,9 @@ impl DnsRecord { ); Ok(DnsRecord::A { - domain: domain, - addr: addr, - ttl: ttl, + domain, + addr, + ttl, }) } QueryType::AAAA => { @@ -83,9 +86,9 @@ impl DnsRecord { ); Ok(DnsRecord::AAAA { - domain: domain, - addr: addr, - ttl: ttl, + domain, + addr, + ttl, }) } QueryType::NS => { @@ -93,9 +96,9 @@ impl DnsRecord { buffer.read_qname(&mut ns)?; Ok(DnsRecord::NS { - domain: domain, + domain, host: ns, - ttl: ttl, + ttl, }) } QueryType::CNAME => { @@ -103,9 +106,9 @@ impl DnsRecord { buffer.read_qname(&mut cname)?; Ok(DnsRecord::CNAME { - domain: domain, + domain, host: cname, - ttl: ttl, + ttl, }) } QueryType::MX => { @@ -114,20 +117,20 @@ impl DnsRecord { buffer.read_qname(&mut mx)?; Ok(DnsRecord::MX { - domain: domain, + domain, priority: priority, host: mx, - ttl: ttl, + ttl, }) } QueryType::UNKNOWN(_) => { buffer.step(data_len as usize)?; Ok(DnsRecord::UNKNOWN { - domain: domain, + domain, qtype: qtype_num, data_len: data_len, - ttl: ttl, + ttl, }) } } @@ -320,7 +323,6 @@ mod tests { assert_eq!(record, read_record); } - #[test] fn test_read_aaaa_record() { let mut buf = BytePacketBuffer::new(); @@ -341,8 +343,8 @@ mod tests { buf.buf[20] = 0x00; buf.buf[21] = 0x10; // data_len = 16 (IPv6 length) buf.buf[22..38].copy_from_slice(&[ - 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, + 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, ]); let record = DnsRecord::read(&mut buf).expect("Failed to read AAAA record"); @@ -350,10 +352,10 @@ mod tests { match record { DnsRecord::AAAA { domain, addr, ttl } => { assert_eq!(domain, "google.com"); - assert_eq!(addr, Ipv6Addr::new( - 0x2001, 0xdb8, 0x0010, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0001, - )); + assert_eq!( + addr, + Ipv6Addr::new(0x2001, 0xdb8, 0x0010, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001,) + ); assert_eq!(ttl, 293); } _ => panic!("Record is not of type AAAA"), @@ -365,8 +367,7 @@ mod tests { let record = DnsRecord::AAAA { domain: "google.com".to_string(), addr: Ipv6Addr::new( - 0x2001, 0xdb8, 0x0010, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0001, + 0x2001, 0xdb8, 0x0010, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, ), ttl: 293, }; @@ -389,10 +390,13 @@ mod tests { assert_eq!(buf.buf[19], 0x25); // ttl = 293 assert_eq!(buf.buf[20], 0x00); assert_eq!(buf.buf[21], 0x10); // data_len = 16 - assert_eq!(&buf.buf[22..38], &[ - 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x10, 0x00, 0x00, - 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x01, - ]); + assert_eq!( + &buf.buf[22..38], + &[ + 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + ] + ); } #[test] @@ -400,8 +404,7 @@ mod tests { let record = DnsRecord::AAAA { domain: "google.com".to_string(), addr: Ipv6Addr::new( - 0x2001, 0xdb8, 0x0010, 0x0000, - 0x0000, 0x0000, 0x0000, 0x0001, + 0x2001, 0xdb8, 0x0010, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, ), ttl: 293, }; @@ -457,7 +460,9 @@ mod tests { }; let mut buf = BytePacketBuffer::new(); - record.write(&mut buf).expect("Failed to write CNAME record"); + record + .write(&mut buf) + .expect("Failed to write CNAME record"); assert_eq!(buf.buf[0], 6); assert_eq!(&buf.buf[1..7], b"google"); @@ -487,7 +492,9 @@ mod tests { }; let mut buf = BytePacketBuffer::new(); - record.write(&mut buf).expect("Failed to write CNAME record"); + record + .write(&mut buf) + .expect("Failed to write CNAME record"); assert!(buf.seek(0).is_ok()); let read_record = DnsRecord::read(&mut buf).expect("Failed to read CNAME record"); @@ -524,7 +531,12 @@ mod tests { let record = DnsRecord::read(&mut buf).expect("Failed to read MX record"); match record { - DnsRecord::MX { domain, priority, host, ttl } => { + DnsRecord::MX { + domain, + priority, + host, + ttl, + } => { assert_eq!(domain, "google.com"); assert_eq!(priority, 10); assert_eq!(host, "mx1.com"); @@ -672,5 +684,4 @@ mod tests { assert_eq!(record, read_record); } - -} \ No newline at end of file +} diff --git a/src/lib.rs b/src/lib.rs index 44b14f1..b93bb21 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -1,12 +1,17 @@ -use std::{io::{self}, net::{Ipv4Addr, UdpSocket}}; +use std::{ + io::{self}, + net::{Ipv4Addr, UdpSocket}, +}; mod byte_packet_buffer; mod dns_header; +mod dns_packet; mod dns_question; mod dns_record; mod query_types; -mod dns_packet; -use {byte_packet_buffer::BytePacketBuffer, dns_header::ResultCode, dns_question::DnsQuestion, query_types::QueryType, dns_packet::DnsPacket}; - +use { + byte_packet_buffer::BytePacketBuffer, dns_header::ResultCode, dns_packet::DnsPacket, + dns_question::DnsQuestion, query_types::QueryType, +}; fn recursive_lookup(qname: &str, qtype: QueryType) -> Result { // For now we're always starting with *a.root-servers.net*. @@ -63,7 +68,7 @@ fn lookup(qname: &str, qtype: QueryType, server: (Ipv4Addr, u16)) -> Result Result Result<(), io::Error> { let mut req_buffer = BytePacketBuffer::new(); - let (_, src) = socket.recv_from(&mut req_buffer.buf)?; - let mut request = DnsPacket::from_buffer(&mut req_buffer)?; let mut packet = DnsPacket::new(); @@ -92,9 +95,9 @@ pub fn handle_query(socket: &UdpSocket) -> Result<(), io::Error> { if let Some(question) = request.questions.pop() { println!("Received query: {:?}", question); - if let Ok(result) = recursive_lookup(&question.name, question.qtype) { - packet.questions.push(question.clone()); - packet.header.rescode = result.header.rescode; + if let Ok(result) = recursive_lookup(&question.name, question.qtype) { + packet.questions.push(question.clone()); + packet.header.rescode = result.header.rescode; for rec in result.answers { println!("Answer: {:?}", rec); @@ -124,4 +127,4 @@ pub fn handle_query(socket: &UdpSocket) -> Result<(), io::Error> { socket.send_to(data, src)?; Ok(()) -} \ No newline at end of file +} diff --git a/src/main.rs b/src/main.rs index 5f4bf93..4cc9ce4 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,7 +1,5 @@ use std::{io, net::UdpSocket}; - - fn main() -> Result<(), io::Error> { // Bind an UDP socket on port 2053 let socket = UdpSocket::bind(("0.0.0.0", 2053))?; @@ -10,7 +8,7 @@ fn main() -> Result<(), io::Error> { // requests is initiated. loop { match dnsserver_nabil::handle_query(&socket) { - Ok(_) => {}, + Ok(_) => {} Err(e) => eprintln!("An error occurred: {}", e), } } diff --git a/src/query_types.rs b/src/query_types.rs index ac993bd..dd2c47f 100644 --- a/src/query_types.rs +++ b/src/query_types.rs @@ -30,4 +30,4 @@ impl QueryType { _ => QueryType::UNKNOWN(num), } } -} \ No newline at end of file +} From 81649822637403efa48670acd46d95da5eccf9d1 Mon Sep 17 00:00:00 2001 From: nabil salah Date: Sat, 21 Dec 2024 22:24:54 +0200 Subject: [PATCH 04/12] fix: project fmted Signed-off-by: nabil salah --- src/dns_question.rs | 5 +---- src/dns_record.rs | 12 ++---------- 2 files changed, 3 insertions(+), 14 deletions(-) diff --git a/src/dns_question.rs b/src/dns_question.rs index 13fd4d3..eaacb2f 100644 --- a/src/dns_question.rs +++ b/src/dns_question.rs @@ -10,10 +10,7 @@ pub struct DnsQuestion { impl DnsQuestion { pub fn new(name: String, qtype: QueryType) -> DnsQuestion { - DnsQuestion { - name, - qtype, - } + DnsQuestion { name, qtype } } /// Read Dns question. pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { diff --git a/src/dns_record.rs b/src/dns_record.rs index 789763c..20bb5dd 100644 --- a/src/dns_record.rs +++ b/src/dns_record.rs @@ -63,11 +63,7 @@ impl DnsRecord { ((raw_addr >> 0) & 0xFF) as u8, ); - Ok(DnsRecord::A { - domain, - addr, - ttl, - }) + Ok(DnsRecord::A { domain, addr, ttl }) } QueryType::AAAA => { let raw_addr1 = buffer.read_u32()?; @@ -85,11 +81,7 @@ impl DnsRecord { ((raw_addr4 >> 0) & 0xFFFF) as u16, ); - Ok(DnsRecord::AAAA { - domain, - addr, - ttl, - }) + Ok(DnsRecord::AAAA { domain, addr, ttl }) } QueryType::NS => { let mut ns = String::new(); From 0cb0312e530e3eabd5ef2b015a1705650ff62502 Mon Sep 17 00:00:00 2001 From: nabil salah Date: Sat, 21 Dec 2024 23:53:52 +0200 Subject: [PATCH 05/12] feat: app containrize Signed-off-by: nabil salah --- Cargo.toml | 3 +++ Dockerfile | 25 +++++++++++++++++++++++++ docker-compose.yml | 9 +++++++++ 3 files changed, 37 insertions(+) create mode 100644 Dockerfile create mode 100644 docker-compose.yml diff --git a/Cargo.toml b/Cargo.toml index ea6f41a..62000e0 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,3 +4,6 @@ version = "0.1.0" edition = "2021" [dependencies] + +[profile.release] +target = "x86_64-unknown-linux-musl" \ No newline at end of file diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..6808c47 --- /dev/null +++ b/Dockerfile @@ -0,0 +1,25 @@ +# Stage 1: Build Stage +FROM rust:slim as builder + +RUN rustup target add x86_64-unknown-linux-musl +RUN apt-get update && apt-get install -y musl-tools + +WORKDIR /src + +COPY . /src + +RUN cargo build --release --target x86_64-unknown-linux-musl + +FROM alpine:3.19 + +WORKDIR /app + +COPY --from=builder /src/target/x86_64-unknown-linux-musl/release/dnsserver-nabil /app/dnsserver-nabil +COPY --from=builder /src/Cargo.toml /app/Cargo.toml + +# Ensure the binary is executable +RUN chmod +x /app/dnsserver-nabil + +EXPOSE 2053/udp + +CMD ["./dnsserver-nabil"] diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..1df42ba --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,9 @@ +version: '3.8' + +services: + dns-server: + image: dns-server-image + build: . + ports: + - "2053:2053/udp" + restart: always From e6b53e3188606338072416aa1e6404510557f579 Mon Sep 17 00:00:00 2001 From: nabil salah Date: Sun, 22 Dec 2024 00:04:53 +0200 Subject: [PATCH 06/12] update: makefile to support run using docker --- makefile | 14 +++++++++++++- 1 file changed, 13 insertions(+), 1 deletion(-) diff --git a/makefile b/makefile index 7109e55..0951beb 100644 --- a/makefile +++ b/makefile @@ -5,4 +5,16 @@ test: cargo test doc: - cargo doc --open \ No newline at end of file + cargo doc --open + +docker_build: + docker build -t dnsserver . + +docker_run: + docker run -p 2053:2053/udp dnsserver + +docker_compose_up: + docker-compose up + +docker_compose_down: + docker-compose down \ No newline at end of file From fb195e759787fb595457af1acb8398ed595c5b48 Mon Sep 17 00:00:00 2001 From: Nabil Salah Date: Sun, 22 Dec 2024 00:15:18 +0200 Subject: [PATCH 07/12] Update README.md: add video demo --- README.md | 10 +++++++++- 1 file changed, 9 insertions(+), 1 deletion(-) diff --git a/README.md b/README.md index e770c7b..8174853 100644 --- a/README.md +++ b/README.md @@ -58,7 +58,7 @@ make test Once the server is running, you can send DNS queries using any DNS client or command-line tools like `dig` or `nslookup`. For example: ```bash -dig @localhost -p 2053 example.com +dig @localhost -p 2053 google.com ``` ## Documentation @@ -69,6 +69,14 @@ This project is thoroughly documented. You can find detailed explanations of the make doc ``` +## Video demo + + + +https://github.com/user-attachments/assets/c58013ec-6b91-48ef-b89c-2653a87b15da + + + ## Contributing Contributions are welcome! If you have suggestions for improvements or new features, feel free to open an issue or submit a pull request. From ca76b3fe80435976a5bbb1637a2f848151316179 Mon Sep 17 00:00:00 2001 From: nabil salah Date: Mon, 23 Dec 2024 17:52:55 +0200 Subject: [PATCH 08/12] feat: add use of command line args Signed-off-by: nabil salah --- Cargo.lock | 230 ++++++++++++++++++++++++++++++++++++++++++++++++++++ Cargo.toml | 3 +- README.md | 13 +++ src/lib.rs | 21 +++-- src/main.rs | 21 ++++- 5 files changed, 278 insertions(+), 10 deletions(-) diff --git a/Cargo.lock b/Cargo.lock index 1179058..1563e94 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -2,6 +2,236 @@ # It is not intended for manual editing. version = 3 +[[package]] +name = "anstream" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + +[[package]] +name = "anstyle-parse" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +dependencies = [ + "anstyle", + "windows-sys", +] + +[[package]] +name = "clap" +version = "4.5.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" + +[[package]] +name = "colorchoice" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" + [[package]] name = "dnsserver-nabil" version = "0.1.0" +dependencies = [ + "clap", +] + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + +[[package]] +name = "proc-macro2" +version = "1.0.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "syn" +version = "2.0.90" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" diff --git a/Cargo.toml b/Cargo.toml index 62000e0..ad56c4e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -4,6 +4,7 @@ version = "0.1.0" edition = "2021" [dependencies] +clap = { version = "4.5.23", features = ["derive"] } [profile.release] -target = "x86_64-unknown-linux-musl" \ No newline at end of file +target = "x86_64-unknown-linux-musl" diff --git a/README.md b/README.md index 8174853..c22b262 100644 --- a/README.md +++ b/README.md @@ -35,6 +35,19 @@ A simple and efficient DNS server implemented in Rust, designed to handle DNS qu cargo build ``` +### Command Line Options +This is the result of running hermes -h +```bash +A simple DNS server application + +Usage: dnsserver-nabil [OPTIONS] + +Options: + -p, --port Port to bind the UDP socket [default: 2053] + -f, --forward-ip forward replies to specified dns server + -h, --help Print help + -V, --version Print version +``` ### Running the Server To run the DNS server, use the following command: diff --git a/src/lib.rs b/src/lib.rs index b93bb21..79d84ca 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -13,9 +13,18 @@ use { dns_question::DnsQuestion, query_types::QueryType, }; -fn recursive_lookup(qname: &str, qtype: QueryType) -> Result { - // For now we're always starting with *a.root-servers.net*. - let mut ns = "198.41.0.4".parse::().unwrap(); +fn recursive_lookup( + qname: &str, + qtype: QueryType, + root_ip: &Option, +) -> Result { + // For now we're always starting with *a.root-servers.net* if no forwarding ip specified. + let mut ns = match root_ip { + Some(ip) => ip, + None => "198.41.0.4", + } + .parse::() + .unwrap(); loop { println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns); @@ -44,7 +53,7 @@ fn recursive_lookup(qname: &str, qtype: QueryType) -> Result return Ok(response), }; - let recursive_response = recursive_lookup(&new_ns_name, QueryType::A)?; + let recursive_response = recursive_lookup(&new_ns_name, QueryType::A, root_ip)?; if let Some(new_ns) = recursive_response.get_random_a() { ns = new_ns; @@ -80,7 +89,7 @@ fn lookup(qname: &str, qtype: QueryType, server: (Ipv4Addr, u16)) -> Result Result<(), io::Error> { +pub fn handle_query(socket: &UdpSocket, root_ip: &Option) -> Result<(), io::Error> { let mut req_buffer = BytePacketBuffer::new(); let (_, src) = socket.recv_from(&mut req_buffer.buf)?; @@ -95,7 +104,7 @@ pub fn handle_query(socket: &UdpSocket) -> Result<(), io::Error> { if let Some(question) = request.questions.pop() { println!("Received query: {:?}", question); - if let Ok(result) = recursive_lookup(&question.name, question.qtype) { + if let Ok(result) = recursive_lookup(&question.name, question.qtype, root_ip) { packet.questions.push(question.clone()); packet.header.rescode = result.header.rescode; diff --git a/src/main.rs b/src/main.rs index 4cc9ce4..08c78a5 100644 --- a/src/main.rs +++ b/src/main.rs @@ -1,13 +1,28 @@ +use clap::Parser; use std::{io, net::UdpSocket}; +#[derive(Parser)] +#[command(name = "DnsServerApp")] +#[command(version = "1.0")] +#[command(about = "A simple DNS server application", long_about = None)] +struct Args { + /// Port to bind the UDP socket + #[arg(short, long, default_value_t = 2053)] + port: u16, + /// forward replies to specified dns server + #[arg(short, long, default_value = None)] + forward_ip: Option, +} + fn main() -> Result<(), io::Error> { - // Bind an UDP socket on port 2053 - let socket = UdpSocket::bind(("0.0.0.0", 2053))?; + let args = Args::parse(); + // Bind an UDP socket on port + let socket = UdpSocket::bind(("0.0.0.0", args.port))?; // For now, queries are handled sequentially, so an infinite loop for servicing // requests is initiated. loop { - match dnsserver_nabil::handle_query(&socket) { + match dnsserver_nabil::handle_query(&socket, &args.forward_ip) { Ok(_) => {} Err(e) => eprintln!("An error occurred: {}", e), } From 2dc082a2c1b157046dc67accd2fccf1bb6e47b15 Mon Sep 17 00:00:00 2001 From: Nabil Salah Date: Mon, 23 Dec 2024 18:06:16 +0200 Subject: [PATCH 09/12] Update README.md to use docker compose --- README.md | 10 ++++++++++ 1 file changed, 10 insertions(+) diff --git a/README.md b/README.md index c22b262..d0825cb 100644 --- a/README.md +++ b/README.md @@ -56,6 +56,16 @@ To run the DNS server, use the following command: make run ``` +### Running the Server using docker compose + +To run the DNS server but insure that you have docker installed on your machine [guidance](https://docs.docker.com/compose/install/). + +then use the following command : + +```bash +make docker_compose_up +``` + ### Running Tests To ensure everything is working correctly, you can run the test suite with: From 08c76e524060253ec4402ec37f82530dcbc3c264 Mon Sep 17 00:00:00 2001 From: nabil salah Date: Mon, 23 Dec 2024 21:27:23 +0200 Subject: [PATCH 10/12] refactor: remove unused macros Signed-off-by: nabil salah --- src/dns_question.rs | 2 +- src/dns_record.rs | 2 +- src/query_types.rs | 2 +- 3 files changed, 3 insertions(+), 3 deletions(-) diff --git a/src/dns_question.rs b/src/dns_question.rs index eaacb2f..f2ad13b 100644 --- a/src/dns_question.rs +++ b/src/dns_question.rs @@ -2,7 +2,7 @@ use std::io; use crate::{byte_packet_buffer::BytePacketBuffer, query_types::QueryType}; -#[derive(Debug, Clone, PartialEq, Eq)] +#[derive(Debug, Clone)] pub struct DnsQuestion { pub name: String, pub qtype: QueryType, diff --git a/src/dns_record.rs b/src/dns_record.rs index 20bb5dd..65a2f2e 100644 --- a/src/dns_record.rs +++ b/src/dns_record.rs @@ -5,7 +5,7 @@ use std::{ use crate::{byte_packet_buffer::BytePacketBuffer, QueryType}; -#[derive(Debug, Clone, PartialEq, Eq, Hash, PartialOrd, Ord)] +#[derive(Debug, Clone, PartialEq, Eq)] pub enum DnsRecord { UNKNOWN { domain: String, diff --git a/src/query_types.rs b/src/query_types.rs index dd2c47f..bb0f36a 100644 --- a/src/query_types.rs +++ b/src/query_types.rs @@ -1,4 +1,4 @@ -#[derive(PartialEq, Eq, Debug, Clone, Hash, Copy)] +#[derive(PartialEq, Eq, Debug, Clone, Copy)] pub enum QueryType { UNKNOWN(u16), A, // 1 From f9342246abb77d244a5388f9e6154a767c2a3484 Mon Sep 17 00:00:00 2001 From: nabil salah Date: Fri, 27 Dec 2024 14:15:13 +0200 Subject: [PATCH 11/12] fix:use Ephemeral port Signed-off-by: nabil salah --- Dockerfile | 1 - src/lib.rs | 2 +- 2 files changed, 1 insertion(+), 2 deletions(-) diff --git a/Dockerfile b/Dockerfile index 6808c47..042e67a 100644 --- a/Dockerfile +++ b/Dockerfile @@ -15,7 +15,6 @@ FROM alpine:3.19 WORKDIR /app COPY --from=builder /src/target/x86_64-unknown-linux-musl/release/dnsserver-nabil /app/dnsserver-nabil -COPY --from=builder /src/Cargo.toml /app/Cargo.toml # Ensure the binary is executable RUN chmod +x /app/dnsserver-nabil diff --git a/src/lib.rs b/src/lib.rs index 79d84ca..0ffa0d9 100644 --- a/src/lib.rs +++ b/src/lib.rs @@ -64,7 +64,7 @@ fn recursive_lookup( } fn lookup(qname: &str, qtype: QueryType, server: (Ipv4Addr, u16)) -> Result { - let socket = UdpSocket::bind(("0.0.0.0", 43210))?; + let socket = UdpSocket::bind(("0.0.0.0", 49152))?; let mut packet = DnsPacket::new(); From 81de2bd3577ddcac37530f2fefdd6d3226d41706 Mon Sep 17 00:00:00 2001 From: nabil salah Date: Fri, 27 Dec 2024 15:43:02 +0200 Subject: [PATCH 12/12] fix:docker compose ambiguous start Signed-off-by: nabil salah --- src/main.rs | 1 + 1 file changed, 1 insertion(+) diff --git a/src/main.rs b/src/main.rs index 08c78a5..c5681a9 100644 --- a/src/main.rs +++ b/src/main.rs @@ -18,6 +18,7 @@ fn main() -> Result<(), io::Error> { let args = Args::parse(); // Bind an UDP socket on port let socket = UdpSocket::bind(("0.0.0.0", args.port))?; + println!("Server listening on port {}", args.port); // For now, queries are handled sequentially, so an infinite loop for servicing // requests is initiated.