diff --git a/.github/workflows/rust.yaml b/.github/workflows/rust.yaml new file mode 100644 index 0000000..2fe10ab --- /dev/null +++ b/.github/workflows/rust.yaml @@ -0,0 +1,38 @@ +name: Rust + +on: [push] + +jobs: + build: + name: Test-Clippy-Build + runs-on: ubuntu-22.04 + steps: + - uses: actions/checkout@v3 + name: Checkout code + with: + fetch-depth: 1 + - uses: actions-rs/toolchain@v1 + name: Install toolchain + with: + toolchain: stable + - uses: actions-rs/cargo@v1 + name: Check formatting + with: + command: fmt + args: -- --check + - uses: actions-rs/cargo@v1 + name: Run clippy + with: + command: clippy + + - uses: actions-rs/cargo@v1 + name: Run tests + with: + command: test + + - uses: actions-rs/cargo@v1 + name: Build + with: + toolchain: stable + command: build + args: --release diff --git a/.gitignore b/.gitignore new file mode 100644 index 0000000..1acac83 --- /dev/null +++ b/.gitignore @@ -0,0 +1,5 @@ +/target +query_packet.txt +response_packet.txt +.vscode +Authoritative-DNS-server.png \ No newline at end of file diff --git a/Cargo.lock b/Cargo.lock new file mode 100644 index 0000000..1563e94 --- /dev/null +++ b/Cargo.lock @@ -0,0 +1,237 @@ +# This file is automatically @generated by Cargo. +# It is not intended for manual editing. +version = 3 + +[[package]] +name = "anstream" +version = "0.6.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8acc5369981196006228e28809f761875c0327210a891e941f4c683b3a99529b" +dependencies = [ + "anstyle", + "anstyle-parse", + "anstyle-query", + "anstyle-wincon", + "colorchoice", + "is_terminal_polyfill", + "utf8parse", +] + +[[package]] +name = "anstyle" +version = "1.0.10" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "55cc3b69f167a1ef2e161439aa98aed94e6028e5f9a59be9a6ffb47aef1651f9" + +[[package]] +name = "anstyle-parse" +version = "0.2.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3b2d16507662817a6a20a9ea92df6652ee4f94f914589377d69f3b21bc5798a9" +dependencies = [ + "utf8parse", +] + +[[package]] +name = "anstyle-query" +version = "1.1.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "79947af37f4177cfead1110013d678905c37501914fba0efea834c3fe9a8d60c" +dependencies = [ + "windows-sys", +] + +[[package]] +name = "anstyle-wincon" +version = "3.0.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2109dbce0e72be3ec00bed26e6a7479ca384ad226efdd66db8fa2e3a38c83125" +dependencies = [ + "anstyle", + "windows-sys", +] + +[[package]] +name = "clap" +version = "4.5.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "3135e7ec2ef7b10c6ed8950f0f792ed96ee093fa088608f1c76e569722700c84" +dependencies = [ + "clap_builder", + "clap_derive", +] + +[[package]] +name = "clap_builder" +version = "4.5.23" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "30582fc632330df2bd26877bde0c1f4470d57c582bbc070376afcd04d8cb4838" +dependencies = [ + "anstream", + "anstyle", + "clap_lex", + "strsim", +] + +[[package]] +name = "clap_derive" +version = "4.5.18" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "4ac6a0c7b1a9e9a5186361f67dfa1b88213572f427fb9ab038efb2bd8c582dab" +dependencies = [ + "heck", + "proc-macro2", + "quote", + "syn", +] + +[[package]] +name = "clap_lex" +version = "0.7.4" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "f46ad14479a25103f283c0f10005961cf086d8dc42205bb44c46ac563475dca6" + +[[package]] +name = "colorchoice" +version = "1.0.3" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "5b63caa9aa9397e2d9480a9b13673856c78d8ac123288526c37d7839f2a86990" + +[[package]] +name = "dnsserver-nabil" +version = "0.1.0" +dependencies = [ + "clap", +] + +[[package]] +name = "heck" +version = "0.5.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "2304e00983f87ffb38b55b444b5e3b60a884b5d30c0fca7d82fe33449bbe55ea" + +[[package]] +name = "is_terminal_polyfill" +version = "1.70.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7943c866cc5cd64cbc25b2e01621d07fa8eb2a1a23160ee81ce38704e97b8ecf" + +[[package]] +name = "proc-macro2" +version = "1.0.92" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "37d3544b3f2748c54e147655edb5025752e2303145b5aefb3c3ea2c78b973bb0" +dependencies = [ + "unicode-ident", +] + +[[package]] +name = "quote" +version = "1.0.37" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "b5b9d34b8991d19d98081b46eacdd8eb58c6f2b201139f7c5f643cc155a633af" +dependencies = [ + "proc-macro2", +] + +[[package]] +name = "strsim" +version = "0.11.1" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "7da8b5736845d9f2fcb837ea5d9e2628564b3b043a70948a3f0b778838c5fb4f" + +[[package]] +name = "syn" +version = "2.0.90" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "919d3b74a5dd0ccd15aeb8f93e7006bd9e14c295087c9896a110f490752bcf31" +dependencies = [ + "proc-macro2", + "quote", + "unicode-ident", +] + +[[package]] +name = "unicode-ident" +version = "1.0.14" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "adb9e6ca4f869e1180728b7950e35922a7fc6397f7b641499e8f3ef06e50dc83" + +[[package]] +name = "utf8parse" +version = "0.2.2" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "06abde3611657adf66d383f00b093d7faecc7fa57071cce2578660c9f1010821" + +[[package]] +name = "windows-sys" +version = "0.59.0" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "1e38bc4d79ed67fd075bcc251a1c39b32a1776bbe92e5bef1f0bf1f8c531853b" +dependencies = [ + "windows-targets", +] + +[[package]] +name = "windows-targets" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "9b724f72796e036ab90c1021d4780d4d3d648aca59e491e6b98e725b84e99973" +dependencies = [ + "windows_aarch64_gnullvm", + "windows_aarch64_msvc", + "windows_i686_gnu", + "windows_i686_gnullvm", + "windows_i686_msvc", + "windows_x86_64_gnu", + "windows_x86_64_gnullvm", + "windows_x86_64_msvc", +] + +[[package]] +name = "windows_aarch64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "32a4622180e7a0ec044bb555404c800bc9fd9ec262ec147edd5989ccd0c02cd3" + +[[package]] +name = "windows_aarch64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "09ec2a7bb152e2252b53fa7803150007879548bc709c039df7627cabbd05d469" + +[[package]] +name = "windows_i686_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "8e9b5ad5ab802e97eb8e295ac6720e509ee4c243f69d781394014ebfe8bbfa0b" + +[[package]] +name = "windows_i686_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "0eee52d38c090b3caa76c563b86c3a4bd71ef1a819287c19d586d7334ae8ed66" + +[[package]] +name = "windows_i686_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "240948bc05c5e7c6dabba28bf89d89ffce3e303022809e73deaefe4f6ec56c66" + +[[package]] +name = "windows_x86_64_gnu" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "147a5c80aabfbf0c7d901cb5895d1de30ef2907eb21fbbab29ca94c5b08b1a78" + +[[package]] +name = "windows_x86_64_gnullvm" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "24d5b23dc417412679681396f2b49f3de8c1473deb516bd34410872eff51ed0d" + +[[package]] +name = "windows_x86_64_msvc" +version = "0.52.6" +source = "registry+https://github.com/rust-lang/crates.io-index" +checksum = "589f6da84c646204747d1270a2a5661ea66ed1cced2631d546fdfb155959f9ec" diff --git a/Cargo.toml b/Cargo.toml new file mode 100644 index 0000000..ad56c4e --- /dev/null +++ b/Cargo.toml @@ -0,0 +1,10 @@ +[package] +name = "dnsserver-nabil" +version = "0.1.0" +edition = "2021" + +[dependencies] +clap = { version = "4.5.23", features = ["derive"] } + +[profile.release] +target = "x86_64-unknown-linux-musl" diff --git a/Dockerfile b/Dockerfile new file mode 100644 index 0000000..042e67a --- /dev/null +++ b/Dockerfile @@ -0,0 +1,24 @@ +# Stage 1: Build Stage +FROM rust:slim as builder + +RUN rustup target add x86_64-unknown-linux-musl +RUN apt-get update && apt-get install -y musl-tools + +WORKDIR /src + +COPY . /src + +RUN cargo build --release --target x86_64-unknown-linux-musl + +FROM alpine:3.19 + +WORKDIR /app + +COPY --from=builder /src/target/x86_64-unknown-linux-musl/release/dnsserver-nabil /app/dnsserver-nabil + +# Ensure the binary is executable +RUN chmod +x /app/dnsserver-nabil + +EXPOSE 2053/udp + +CMD ["./dnsserver-nabil"] diff --git a/README.md b/README.md index e3d77c5..d0825cb 100644 --- a/README.md +++ b/README.md @@ -1 +1,106 @@ -# dnsserver-nabil \ No newline at end of file +# dnsserver-nabil + +A simple and efficient DNS server implemented in Rust, designed to handle DNS queries and responses. This project includes documentation and tests to ensure functionality and reliability. + +## Features + +- **DNS Query Handling**: Supports various DNS query types, including A, AAAA, MX, NS, and CNAME. +- **Custom Implementation**: Fully implemented from scratch with a focus on clarity and maintainability. +- **Documentation**: Inline documentation explaining the code structure and logic. +- **Tests**: Comprehensive test suite covering all aspects of the DNS server functionality. + +## Getting Started + +### Prerequisites + +- Rust and Cargo installed on your machine. You can install Rust using [rustup](https://rustup.rs/). + +### Installation + +1. Clone the repository: + + ```bash + git clone https://github.com/codescalersinternships/dnsserver-nabil/tree/development + ``` + +2. Navigate to the project directory: + + ```bash + cd dnsserver-nabil + ``` + +3. Build the project: + + ```bash + cargo build + ``` + +### Command Line Options +This is the result of running hermes -h +```bash +A simple DNS server application + +Usage: dnsserver-nabil [OPTIONS] + +Options: + -p, --port Port to bind the UDP socket [default: 2053] + -f, --forward-ip forward replies to specified dns server + -h, --help Print help + -V, --version Print version +``` +### Running the Server + +To run the DNS server, use the following command: + +```bash +make run +``` + +### Running the Server using docker compose + +To run the DNS server but insure that you have docker installed on your machine [guidance](https://docs.docker.com/compose/install/). + +then use the following command : + +```bash +make docker_compose_up +``` + +### Running Tests + +To ensure everything is working correctly, you can run the test suite with: + +```bash +make test +``` + + + +## Usage + +Once the server is running, you can send DNS queries using any DNS client or command-line tools like `dig` or `nslookup`. For example: + +```bash +dig @localhost -p 2053 google.com +``` + +## Documentation + +This project is thoroughly documented. You can find detailed explanations of the code and its functionality within the source files. Additionally, you can generate and view the documentation using: + +```bash +make doc +``` + +## Video demo + + + +https://github.com/user-attachments/assets/c58013ec-6b91-48ef-b89c-2653a87b15da + + + +## Contributing + +Contributions are welcome! If you have suggestions for improvements or new features, feel free to open an issue or submit a pull request. + diff --git a/docker-compose.yml b/docker-compose.yml new file mode 100644 index 0000000..1df42ba --- /dev/null +++ b/docker-compose.yml @@ -0,0 +1,9 @@ +version: '3.8' + +services: + dns-server: + image: dns-server-image + build: . + ports: + - "2053:2053/udp" + restart: always diff --git a/makefile b/makefile new file mode 100644 index 0000000..0951beb --- /dev/null +++ b/makefile @@ -0,0 +1,20 @@ +run: + cargo run + +test: + cargo test + +doc: + cargo doc --open + +docker_build: + docker build -t dnsserver . + +docker_run: + docker run -p 2053:2053/udp dnsserver + +docker_compose_up: + docker-compose up + +docker_compose_down: + docker-compose down \ No newline at end of file diff --git a/src/byte_packet_buffer.rs b/src/byte_packet_buffer.rs new file mode 100644 index 0000000..a918a1c --- /dev/null +++ b/src/byte_packet_buffer.rs @@ -0,0 +1,377 @@ +use std::io::{self, Error}; + +pub struct BytePacketBuffer { + pub buf: [u8; 512], + pub pos: usize, +} + +impl BytePacketBuffer { + /// This gives us a fresh buffer for holding the packet contents, and a + /// field for keeping track of where we are. + pub fn new() -> BytePacketBuffer { + BytePacketBuffer { + buf: [0; 512], + pos: 0, + } + } + + /// Current position within buffer + pub fn pos(&self) -> usize { + self.pos + } + + /// Step the buffer position forward a specific number of steps + pub fn step(&mut self, steps: usize) -> Result<(), io::Error> { + if self.pos + steps >= 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + self.pos += steps; + Ok(()) + } + + /// Change the buffer position + pub fn seek(&mut self, pos: usize) -> Result<(), io::Error> { + if pos >= 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + self.pos = pos; + Ok(()) + } + + /// Read a single byte and move the position one step forward + pub fn read(&mut self) -> Result { + if self.pos >= 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + let res = self.buf[self.pos]; + self.pos += 1; + Ok(res) + } + + /// Get a single byte, without changing the buffer position + pub fn get(&self, pos: usize) -> Result { + if pos >= 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + Ok(self.buf[pos]) + } + + /// Get a range of bytes + pub fn get_range(&self, start: usize, len: usize) -> Result<&[u8], io::Error> { + if start + len > 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + Ok(&self.buf[start..(start + len)]) + } + + /// Read two bytes, stepping two steps forward + pub fn read_u16(&mut self) -> Result { + let res = ((self.read()? as u16) << 8) | (self.read()? as u16); + Ok(res) + } + + /// Read four bytes, stepping four steps forward + pub fn read_u32(&mut self) -> Result { + let res = ((self.read()? as u32) << 24) + | ((self.read()? as u32) << 16) + | ((self.read()? as u32) << 8) + | ((self.read()? as u32) << 0); + + Ok(res) + } + + /// Read a qname + /// + /// Will take something like [3]www[6]google[3]com[0] and append + /// www.google.com to outstr. + /// + /// also it handle jemps . + pub fn read_qname(&mut self, outstr: &mut String) -> Result<(), io::Error> { + let mut pos = self.pos(); + let mut jumped = false; + let max_jumps = 5; + let mut jumps_performed = 0; + let mut delim = ""; + loop { + if jumps_performed > max_jumps { + return Err(Error::new( + io::ErrorKind::UnexpectedEof, + format!("Limit of {} jumps exceeded", max_jumps), + )); + } + + let len = self.get(pos)?; + + if (len & 0xC0) == 0xC0 { + if !jumped { + self.seek(pos + 2)? + } + + let b2 = self.get(pos + 1)? as u16; + let offset = (((len as u16) ^ 0xC0) << 8) | b2; + pos = offset as usize; + + jumped = true; + jumps_performed += 1; + } else { + pos += 1; + + if len == 0 { + break; + } + outstr.push_str(delim); + let str_buffer = self.get_range(pos, len as usize)?; + outstr.push_str(&String::from_utf8_lossy(str_buffer).to_lowercase()); + + delim = "."; + pos += len as usize; + } + } + if !jumped { + self.seek(pos)?; + } + + Ok(()) + } + + /// Write a single byte and move the position one step forward + pub fn write(&mut self, val: u8) -> Result<(), io::Error> { + if self.pos >= 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + self.buf[self.pos] = val; + self.pos += 1; + Ok(()) + } + + /// Read a single byte and move the position one step forward + pub fn write_u8(&mut self, val: u8) -> Result<(), io::Error> { + self.write(val)?; + + Ok(()) + } + + /// Read a two bytes and move the position two steps forward + pub fn write_u16(&mut self, val: u16) -> Result<(), io::Error> { + self.write((val >> 8) as u8)?; + self.write((val & 0xFF) as u8)?; + + Ok(()) + } + + /// Write four bytes, stepping four steps forward + pub fn write_u32(&mut self, val: u32) -> Result<(), io::Error> { + self.write(((val >> 24) & 0xFF) as u8)?; + self.write(((val >> 16) & 0xFF) as u8)?; + self.write(((val >> 8) & 0xFF) as u8)?; + self.write(((val >> 0) & 0xFF) as u8)?; + + Ok(()) + } + + /// Write a qname + /// + /// Will take something like www.google.com + /// + /// dots are the separator used + pub fn write_qname(&mut self, qname: &str) -> Result<(), io::Error> { + for label in qname.split('.') { + let len = label.len(); + if len > 0x3f { + return Err(Error::new( + io::ErrorKind::UnexpectedEof, + "Single label exceeds 63 characters of length", + )); + } + + self.write_u8(len as u8)?; + for b in label.as_bytes() { + self.write_u8(*b)?; + } + } + + self.write_u8(0)?; + + Ok(()) + } + + /// Write a single byte, without changing the buffer position + pub fn set(&mut self, pos: usize, val: u8) -> Result<(), io::Error> { + if pos >= 512 { + return Err(Error::new(io::ErrorKind::UnexpectedEof, "End of buffer")); + } + self.buf[pos] = val; + + Ok(()) + } + + /// Get two bytes, without changing the buffer position + pub fn set_u16(&mut self, pos: usize, val: u16) -> Result<(), io::Error> { + self.set(pos, (val >> 8) as u8)?; + self.set(pos + 1, (val & 0xFF) as u8)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + #[test] + fn test_initial_position() { + let buf = BytePacketBuffer::new(); + assert_eq!(buf.pos(), 0); + } + + #[test] + fn test_pos_step_seek_limit() { + let mut buf = BytePacketBuffer::new(); + assert!(buf.step(511).is_ok()); + assert_eq!(buf.pos(), 511); + assert!(buf.step(1).is_err()); + assert!(buf.seek(511).is_ok()); + assert!(buf.seek(512).is_err()); + } + + #[test] + fn test_read_write_limits() { + let mut buf = BytePacketBuffer::new(); + + // Test reading at limit + assert!(buf.seek(511).is_ok()); + assert!(buf.read().is_ok()); + assert!(buf.read().is_err()); + + // Test writing at limit + assert!(buf.seek(511).is_ok()); + assert!(buf.write(0xFF).is_ok()); + assert!(buf.write(0xFF).is_err()); + assert_eq!(buf.buf[511], 0xFF); + } + + #[test] + fn test_read_u16_u32() { + let mut buf = BytePacketBuffer::new(); + buf.buf[0] = 0x12; + buf.buf[1] = 0x34; + buf.buf[2] = 0x56; + buf.buf[3] = 0x78; + + assert_eq!(buf.read_u16().unwrap(), 0x1234); + assert!(buf.seek(0).is_ok()); + assert_eq!(buf.read_u32().unwrap(), 0x12345678); + } + + #[test] + fn test_set_get() { + let mut buf = BytePacketBuffer::new(); + assert!(buf.set(511, 0xAB).is_ok()); + assert_eq!(buf.get(511).unwrap(), 0xAB); + assert!(buf.set(512, 0xAB).is_err()); + } + + #[test] + fn test_read_qname_no_jump() { + let mut buf = BytePacketBuffer::new(); + buf.buf[0] = 3; + buf.buf[1] = b'w'; + buf.buf[2] = b'w'; + buf.buf[3] = b'w'; + buf.buf[4] = 6; + buf.buf[5] = b'g'; + buf.buf[6] = b'o'; + buf.buf[7] = b'o'; + buf.buf[8] = b'g'; + buf.buf[9] = b'l'; + buf.buf[10] = b'e'; + buf.buf[11] = 3; + buf.buf[12] = b'c'; + buf.buf[13] = b'o'; + buf.buf[14] = b'm'; + buf.buf[15] = 0; + + let mut outstr = String::new(); + assert!(buf.read_qname(&mut outstr).is_ok()); + assert_eq!(outstr, "www.google.com"); + } + + #[test] + fn test_read_qname_with_single_jump() { + let mut buf = BytePacketBuffer::new(); + buf.buf[20] = 3; + buf.buf[21] = b'w'; + buf.buf[22] = b'w'; + buf.buf[23] = b'w'; + buf.buf[24] = 6; + buf.buf[25] = b'g'; + buf.buf[26] = b'o'; + buf.buf[27] = b'o'; + buf.buf[28] = b'g'; + buf.buf[29] = b'l'; + buf.buf[30] = b'e'; + buf.buf[31] = 3; + buf.buf[32] = b'c'; + buf.buf[33] = b'o'; + buf.buf[34] = b'm'; + buf.buf[35] = 0; + + buf.buf[0] = 0xC0; + buf.buf[1] = 20; + + let mut outstr = String::new(); + assert!(buf.read_qname(&mut outstr).is_ok()); + assert_eq!(outstr, "www.google.com"); + } + + #[test] + fn test_read_qname_with_multiple_jumps() { + let mut buf = BytePacketBuffer::new(); + // "google.com" at position 20 + buf.buf[20] = 6; + buf.buf[21] = b'g'; + buf.buf[22] = b'o'; + buf.buf[23] = b'o'; + buf.buf[24] = b'g'; + buf.buf[25] = b'l'; + buf.buf[26] = b'e'; + buf.buf[27] = 3; + buf.buf[28] = b'c'; + buf.buf[29] = b'o'; + buf.buf[30] = b'm'; + buf.buf[31] = 0; + + // "www" at position 50 + buf.buf[50] = 3; + buf.buf[51] = b'w'; + buf.buf[52] = b'w'; + buf.buf[53] = b'w'; + buf.buf[54] = 0xC0; // Pointer to "google.com" + buf.buf[55] = 20; + + // Pointer to "www" at start + buf.buf[0] = 0xC0; + buf.buf[1] = 50; + + let mut outstr = String::new(); + assert!(buf.read_qname(&mut outstr).is_ok()); + assert_eq!(outstr, "www.google.com"); + } + + #[test] + fn test_write_qname_success() { + let mut buf = BytePacketBuffer::new(); + assert!(buf.write_qname("www.google.com").is_ok()); + + let mut outstr = String::new(); + assert!(buf.seek(0).is_ok()); + assert!(buf.read_qname(&mut outstr).is_ok()); + assert_eq!(outstr, "www.google.com"); + } + + #[test] + fn test_write_qname_overflow() { + let mut buf = BytePacketBuffer::new(); + assert!(buf.seek(510).is_ok()); + assert!(buf.write_qname("a").is_err()); + } +} diff --git a/src/dns_header.rs b/src/dns_header.rs new file mode 100644 index 0000000..d4468a9 --- /dev/null +++ b/src/dns_header.rs @@ -0,0 +1,257 @@ +use std::io; + +use crate::byte_packet_buffer::BytePacketBuffer; + +#[derive(Copy, Clone, Debug, PartialEq, Eq)] +pub enum ResultCode { + NOERROR = 0, + FORMERR = 1, + SERVFAIL = 2, + NXDOMAIN = 3, + NOTIMP = 4, + REFUSED = 5, +} + +impl ResultCode { + pub fn from_num(num: u8) -> ResultCode { + match num { + 1 => ResultCode::FORMERR, + 2 => ResultCode::SERVFAIL, + 3 => ResultCode::NXDOMAIN, + 4 => ResultCode::NOTIMP, + 5 => ResultCode::REFUSED, + 0 | _ => ResultCode::NOERROR, + } + } +} + +#[derive(Clone, Debug)] +pub struct DnsHeader { + pub id: u16, // 16 bits + + pub recursion_desired: bool, // 1 bit + pub truncated_message: bool, // 1 bit + pub authoritative_answer: bool, // 1 bit + pub opcode: u8, // 4 bits + pub response: bool, // 1 bit + + pub rescode: ResultCode, // 4 bits + pub checking_disabled: bool, // 1 bit + pub authed_data: bool, // 1 bit + pub z: bool, // 1 bit + pub recursion_available: bool, // 1 bit + + pub questions: u16, // 16 bits + pub answers: u16, // 16 bits + pub authoritative_entries: u16, // 16 bits + pub resource_entries: u16, // 16 bits +} + +impl DnsHeader { + pub fn new() -> DnsHeader { + DnsHeader { + id: 0, + + recursion_desired: false, + truncated_message: false, + authoritative_answer: false, + opcode: 0, + response: false, + + rescode: ResultCode::NOERROR, + checking_disabled: false, + authed_data: false, + z: false, + recursion_available: false, + + questions: 0, + answers: 0, + authoritative_entries: 0, + resource_entries: 0, + } + } + /// Read Dns packet header. + pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { + self.id = buffer.read_u16()?; + + let mut flags = buffer.read()?; + self.recursion_desired = (flags & (1 << 0)) > 0; + self.truncated_message = (flags & (1 << 1)) > 0; + self.authoritative_answer = (flags & (1 << 2)) > 0; + self.opcode = (flags >> 3) & 0x0F; + self.response = (flags & (1 << 7)) > 0; + + flags = buffer.read()?; + self.rescode = ResultCode::from_num(flags & 0x0F); + self.checking_disabled = (flags & (1 << 4)) > 0; + self.authed_data = (flags & (1 << 5)) > 0; + self.z = (flags & (1 << 6)) > 0; + self.recursion_available = (flags & (1 << 7)) > 0; + + self.questions = buffer.read_u16()?; + self.answers = buffer.read_u16()?; + self.authoritative_entries = buffer.read_u16()?; + self.resource_entries = buffer.read_u16()?; + + Ok(()) + } + /// Write Dns packet header as bytes to buffer. + pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { + buffer.write_u16(self.id)?; + + buffer.write_u8( + (self.recursion_desired as u8) + | ((self.truncated_message as u8) << 1) + | ((self.authoritative_answer as u8) << 2) + | (self.opcode << 3) + | ((self.response as u8) << 7) as u8, + )?; + + buffer.write_u8( + (self.rescode as u8) + | ((self.checking_disabled as u8) << 4) + | ((self.authed_data as u8) << 5) + | ((self.z as u8) << 6) + | ((self.recursion_available as u8) << 7), + )?; + + buffer.write_u16(self.questions)?; + buffer.write_u16(self.answers)?; + buffer.write_u16(self.authoritative_entries)?; + buffer.write_u16(self.resource_entries)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_read_header_from_buffer() { + let mut buf = BytePacketBuffer::new(); + + buf.buf[0] = 0x12; + buf.buf[1] = 0x34; + + buf.buf[2] = 0b10010101; // flags: recursion_desired=1, truncated_message=0, authoritative_answer=1, opcode=2, response=1 + buf.buf[3] = 0b11100101; // flags: rescode=5, checking_disabled=0, authed_data=1, z=1, recursion_available=1 + + buf.buf[4] = 0x00; + buf.buf[5] = 0x01; // questions = 1 + buf.buf[6] = 0x00; + buf.buf[7] = 0x02; // answers = 2 + buf.buf[8] = 0x00; + buf.buf[9] = 0x03; // authoritative_entries = 3 + buf.buf[10] = 0x00; + buf.buf[11] = 0x04; // resource_entries = 4 + + let mut header = DnsHeader::new(); + assert!(header.read(&mut buf).is_ok()); + + assert_eq!(header.id, 0x1234); + assert_eq!(header.recursion_desired, true); + assert_eq!(header.truncated_message, false); + assert_eq!(header.authoritative_answer, true); + assert_eq!(header.opcode, 2); + assert_eq!(header.response, true); + assert_eq!(header.rescode, ResultCode::REFUSED); + assert_eq!(header.checking_disabled, false); + assert_eq!(header.authed_data, true); + assert_eq!(header.z, true); + assert_eq!(header.recursion_available, true); + assert_eq!(header.questions, 1); + assert_eq!(header.answers, 2); + assert_eq!(header.authoritative_entries, 3); + assert_eq!(header.resource_entries, 4); + } + + #[test] + fn test_write_header_to_buffer() { + let mut header = DnsHeader::new(); + header.id = 0x1234; + header.recursion_desired = true; + header.truncated_message = false; + header.authoritative_answer = true; + header.opcode = 2; + header.response = true; + header.rescode = ResultCode::REFUSED; + header.checking_disabled = false; + header.authed_data = true; + header.z = true; + header.recursion_available = true; + header.questions = 1; + header.answers = 2; + header.authoritative_entries = 3; + header.resource_entries = 4; + + let mut buf = BytePacketBuffer::new(); + assert!(header.write(&mut buf).is_ok()); + + assert_eq!(buf.buf[0], 0x12); + assert_eq!(buf.buf[1], 0x34); // id = 0x1234 + + assert_eq!(buf.buf[2], 0b10010101); + assert_eq!(buf.buf[3], 0b11100101); + + assert_eq!(buf.buf[4], 0x00); + assert_eq!(buf.buf[5], 0x01); // questions = 1 + assert_eq!(buf.buf[6], 0x00); + assert_eq!(buf.buf[7], 0x02); // answers = 2 + assert_eq!(buf.buf[8], 0x00); + assert_eq!(buf.buf[9], 0x03); // authoritative_entries = 3 + assert_eq!(buf.buf[10], 0x00); + assert_eq!(buf.buf[11], 0x04); // resource_entries = 4 + } + + #[test] + fn test_read_write() { + let mut header = DnsHeader::new(); + header.id = 0x4321; + header.recursion_desired = true; + header.truncated_message = true; + header.authoritative_answer = true; + header.opcode = 3; + header.response = false; + header.rescode = ResultCode::REFUSED; + header.checking_disabled = true; + header.authed_data = true; + header.z = false; + header.recursion_available = true; + header.questions = 10; + header.answers = 20; + header.authoritative_entries = 30; + header.resource_entries = 40; + + let mut buf = BytePacketBuffer::new(); + assert!(header.write(&mut buf).is_ok()); + + let mut read_header = DnsHeader::new(); + assert!(buf.seek(0).is_ok()); + assert!(read_header.read(&mut buf).is_ok()); + println!("{read_header:?}"); + + assert_eq!(header.id, read_header.id); + assert_eq!(header.recursion_desired, read_header.recursion_desired); + assert_eq!(header.truncated_message, read_header.truncated_message); + assert_eq!( + header.authoritative_answer, + read_header.authoritative_answer + ); + assert_eq!(header.opcode, read_header.opcode); + assert_eq!(header.response, read_header.response); + assert_eq!(header.rescode, read_header.rescode); + assert_eq!(header.checking_disabled, read_header.checking_disabled); + assert_eq!(header.authed_data, read_header.authed_data); + assert_eq!(header.z, read_header.z); + assert_eq!(header.recursion_available, read_header.recursion_available); + assert_eq!(header.questions, read_header.questions); + assert_eq!(header.answers, read_header.answers); + assert_eq!( + header.authoritative_entries, + read_header.authoritative_entries + ); + assert_eq!(header.resource_entries, read_header.resource_entries); + } +} diff --git a/src/dns_packet.rs b/src/dns_packet.rs new file mode 100644 index 0000000..d3fb2b6 --- /dev/null +++ b/src/dns_packet.rs @@ -0,0 +1,273 @@ +use std::{io, net::Ipv4Addr}; + +use crate::{ + byte_packet_buffer::BytePacketBuffer, dns_header::DnsHeader, dns_question::DnsQuestion, + dns_record::DnsRecord, query_types::QueryType, +}; + +#[derive(Clone, Debug)] +pub struct DnsPacket { + pub header: DnsHeader, + pub questions: Vec, + pub answers: Vec, + pub authorities: Vec, + pub resources: Vec, +} + +impl DnsPacket { + pub fn new() -> DnsPacket { + DnsPacket { + header: DnsHeader::new(), + questions: Vec::new(), + answers: Vec::new(), + authorities: Vec::new(), + resources: Vec::new(), + } + } + + /// Read complete Dns packet from BytePacketBuffer + pub fn from_buffer(buffer: &mut BytePacketBuffer) -> Result { + let mut result = DnsPacket::new(); + result.header.read(buffer)?; + + for _ in 0..result.header.questions { + let mut question = DnsQuestion::new("".to_string(), QueryType::UNKNOWN(0)); + question.read(buffer)?; + result.questions.push(question); + } + + for _ in 0..result.header.answers { + let rec = DnsRecord::read(buffer)?; + result.answers.push(rec); + } + for _ in 0..result.header.authoritative_entries { + let rec = DnsRecord::read(buffer)?; + result.authorities.push(rec); + } + for _ in 0..result.header.resource_entries { + let rec = DnsRecord::read(buffer)?; + result.resources.push(rec); + } + + Ok(result) + } + + /// Write complete Dns packet to BytePacketBuffer + pub fn write(&mut self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { + self.header.questions = self.questions.len() as u16; + self.header.answers = self.answers.len() as u16; + self.header.authoritative_entries = self.authorities.len() as u16; + self.header.resource_entries = self.resources.len() as u16; + + self.header.write(buffer)?; + + for question in &self.questions { + question.write(buffer)?; + } + for rec in &self.answers { + rec.write(buffer)?; + } + for rec in &self.authorities { + rec.write(buffer)?; + } + for rec in &self.resources { + rec.write(buffer)?; + } + + Ok(()) + } + + /// Get Random A to be able to pick a random A record from a packet. When we + /// get multiple IP's for a single name, it doesn't matter which one we + /// choose, so in those cases we can now pick one at random. + pub fn get_random_a(&self) -> Option { + self.answers + .iter() + .filter_map(|record| match record { + DnsRecord::A { addr, .. } => Some(*addr), + _ => None, + }) + .next() + } + + /// Get NS helper function which returns an iterator over all name servers in + /// the authorities section, represented as (domain, host) tuples + fn get_ns<'a>(&'a self, qname: &'a str) -> impl Iterator { + self.authorities + .iter() + .filter_map(|record| match record { + DnsRecord::NS { domain, host, .. } => Some((domain.as_str(), host.as_str())), + _ => None, + }) + .filter(move |(domain, _)| qname.ends_with(*domain)) + } + + /// Get Resolved NS + /// as We'll use the fact that name servers often bundle the corresponding + /// A records when replying to an NS query to implement a function that + /// returns the actual IP for an NS record if possible. + pub fn get_resolved_ns(&self, qname: &str) -> Option { + self.get_ns(qname) + .flat_map(|(_, host)| { + self.resources + .iter() + .filter_map(move |record| match record { + DnsRecord::A { domain, addr, .. } if domain == host => Some(addr), + _ => None, + }) + }) + .map(|addr| *addr) + .next() + } + + /// Get Unresolved NS a method for returning the host + /// name of an appropriate name server. + pub fn get_unresolved_ns<'a>(&'a self, qname: &'a str) -> Option<&'a str> { + self.get_ns(qname).map(|(_, host)| host).next() + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_write_and_read_dns_packet() { + let mut buffer = BytePacketBuffer::new(); + let mut packet = DnsPacket::new(); + + // Setup a DNS question + let question = DnsQuestion::new("google.com".to_string(), QueryType::A); + packet.questions.push(question); + + // Setup a DNS answer + let answer = DnsRecord::A { + domain: "google.com".to_string(), + addr: Ipv4Addr::new(216, 58, 211, 142), + ttl: 293, + }; + packet.answers.push(answer); + + // Setup a DNS authority + let authority = DnsRecord::NS { + domain: "google.com".to_string(), + host: "ns1.google.com".to_string(), + ttl: 293, + }; + packet.authorities.push(authority); + + // Setup a DNS resource + let resource = DnsRecord::A { + domain: "ns1.google.com".to_string(), + addr: Ipv4Addr::new(8, 8, 8, 8), + ttl: 293, + }; + packet.resources.push(resource); + + packet + .write(&mut buffer) + .expect("Failed to write DNS packet"); + + assert!(buffer.seek(0).is_ok()); + + let read_packet = DnsPacket::from_buffer(&mut buffer).expect("Failed to read DNS packet"); + + // Validate the read packet + assert_eq!(read_packet.questions.len(), 1); + assert_eq!(read_packet.answers.len(), 1); + assert_eq!(read_packet.authorities.len(), 1); + assert_eq!(read_packet.resources.len(), 1); + + // Validate the question + match &read_packet.questions[0] { + DnsQuestion { name, qtype, .. } => { + assert_eq!(name, "google.com"); + assert_eq!(qtype, &QueryType::A); + } + } + + // Validate the answer + match &read_packet.answers[0] { + DnsRecord::A { domain, addr, ttl } => { + assert_eq!(domain, "google.com"); + assert_eq!(addr, &Ipv4Addr::new(216, 58, 211, 142)); + assert_eq!(*ttl, 293); + } + _ => panic!("Expected A record"), + } + + // Validate the authority + match &read_packet.authorities[0] { + DnsRecord::NS { domain, host, ttl } => { + assert_eq!(domain, "google.com"); + assert_eq!(host, "ns1.google.com"); + assert_eq!(*ttl, 293); + } + _ => panic!("Expected NS record"), + } + + // Validate the resource + match &read_packet.resources[0] { + DnsRecord::A { domain, addr, ttl } => { + assert_eq!(domain, "ns1.google.com"); + assert_eq!(addr, &Ipv4Addr::new(8, 8, 8, 8)); + assert_eq!(*ttl, 293); + } + _ => panic!("Expected A record"), + } + } + + #[test] + fn test_get_random_a() { + let mut packet = DnsPacket::new(); + packet.answers.push(DnsRecord::A { + domain: "google.com".to_string(), + addr: Ipv4Addr::new(216, 58, 211, 142), + ttl: 293, + }); + packet.answers.push(DnsRecord::A { + domain: "example.com".to_string(), + addr: Ipv4Addr::new(93, 184, 216, 34), + ttl: 293, + }); + + let random_a = packet.get_random_a(); + assert!(random_a.is_some()); + let random_a = random_a.unwrap(); + assert!( + random_a == Ipv4Addr::new(216, 58, 211, 142) + || random_a == Ipv4Addr::new(93, 184, 216, 34) + ); + } + + #[test] + fn test_get_resolved_ns() { + let mut packet = DnsPacket::new(); + packet.authorities.push(DnsRecord::NS { + domain: "google.com".to_string(), + host: "ns1.google.com".to_string(), + ttl: 293, + }); + packet.resources.push(DnsRecord::A { + domain: "ns1.google.com".to_string(), + addr: Ipv4Addr::new(8, 8, 8, 8), + ttl: 293, + }); + + let resolved_ns = packet.get_resolved_ns("google.com"); + assert_eq!(resolved_ns, Some(Ipv4Addr::new(8, 8, 8, 8))); + } + + #[test] + fn test_get_unresolved_ns() { + let mut packet = DnsPacket::new(); + packet.authorities.push(DnsRecord::NS { + domain: "google.com".to_string(), + host: "ns1.google.com".to_string(), + ttl: 293, + }); + + let unresolved_ns = packet.get_unresolved_ns("google.com"); + assert_eq!(unresolved_ns, Some("ns1.google.com")); + } +} diff --git a/src/dns_question.rs b/src/dns_question.rs new file mode 100644 index 0000000..f2ad13b --- /dev/null +++ b/src/dns_question.rs @@ -0,0 +1,88 @@ +use std::io; + +use crate::{byte_packet_buffer::BytePacketBuffer, query_types::QueryType}; + +#[derive(Debug, Clone)] +pub struct DnsQuestion { + pub name: String, + pub qtype: QueryType, +} + +impl DnsQuestion { + pub fn new(name: String, qtype: QueryType) -> DnsQuestion { + DnsQuestion { name, qtype } + } + /// Read Dns question. + pub fn read(&mut self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { + buffer.read_qname(&mut self.name)?; + self.qtype = QueryType::from_num(buffer.read_u16()?); // qtype + let _ = buffer.read_u16()?; // class + Ok(()) + } + + /// Write Dns question as bytes to buffer. + pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result<(), io::Error> { + buffer.write_qname(&self.name)?; + + let typenum = self.qtype.to_num(); + buffer.write_u16(typenum)?; + buffer.write_u16(1)?; + + Ok(()) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_read_question_from_buffer() { + let mut buf = BytePacketBuffer::new(); + + buf.buf[0] = 6; + buf.buf[1..7].copy_from_slice(b"google"); + buf.buf[7] = 3; + buf.buf[8..11].copy_from_slice(b"com"); + buf.buf[11] = 0; + buf.buf[12] = 0x00; + buf.buf[13] = 0x01; // qtype A + + let mut question = DnsQuestion::new("".to_string(), QueryType::A); + assert!(question.read(&mut buf).is_ok()); + + assert_eq!(question.name, "google.com"); + assert_eq!(question.qtype, QueryType::A); + } + + #[test] + fn test_write_question_to_buffer() { + let question = DnsQuestion::new("google.com".to_string(), QueryType::A); + + let mut buf = BytePacketBuffer::new(); + assert!(question.write(&mut buf).is_ok()); + + assert_eq!(buf.buf[0], 6); + assert_eq!(&buf.buf[1..7], b"google"); + assert_eq!(buf.buf[7], 3); + assert_eq!(&buf.buf[8..11], b"com"); + assert_eq!(buf.buf[11], 0); + assert_eq!(buf.buf[12], 0x00); + assert_eq!(buf.buf[13], 0x01); // qtype A + } + + #[test] + fn test_read_write() { + let question = DnsQuestion::new("google.com".to_string(), QueryType::A); + + let mut buf = BytePacketBuffer::new(); + assert!(question.write(&mut buf).is_ok()); + + let mut read_question = DnsQuestion::new("".to_string(), QueryType::A); + assert!(buf.seek(0).is_ok()); + assert!(read_question.read(&mut buf).is_ok()); + + assert_eq!(question.name, read_question.name); + assert_eq!(question.qtype, read_question.qtype); + } +} diff --git a/src/dns_record.rs b/src/dns_record.rs new file mode 100644 index 0000000..65a2f2e --- /dev/null +++ b/src/dns_record.rs @@ -0,0 +1,679 @@ +use std::{ + io, + net::{Ipv4Addr, Ipv6Addr}, +}; + +use crate::{byte_packet_buffer::BytePacketBuffer, QueryType}; + +#[derive(Debug, Clone, PartialEq, Eq)] +pub enum DnsRecord { + UNKNOWN { + domain: String, + qtype: u16, + data_len: u16, + ttl: u32, + }, // 0 + A { + domain: String, + addr: Ipv4Addr, + ttl: u32, + }, // 1 + NS { + domain: String, + host: String, + ttl: u32, + }, // 2 + CNAME { + domain: String, + host: String, + ttl: u32, + }, // 5 + MX { + domain: String, + priority: u16, + host: String, + ttl: u32, + }, // 15 + AAAA { + domain: String, + addr: Ipv6Addr, + ttl: u32, + }, // 28 +} + +impl DnsRecord { + /// Read Dns record (Answer Section - Authority Section - Additional Section) + pub fn read(buffer: &mut BytePacketBuffer) -> Result { + let mut domain = String::new(); + buffer.read_qname(&mut domain)?; + + let qtype_num = buffer.read_u16()?; + let qtype = QueryType::from_num(qtype_num); + let _ = buffer.read_u16()?; + let ttl = buffer.read_u32()?; + let data_len = buffer.read_u16()?; + + match qtype { + QueryType::A => { + let raw_addr = buffer.read_u32()?; + let addr = Ipv4Addr::new( + ((raw_addr >> 24) & 0xFF) as u8, + ((raw_addr >> 16) & 0xFF) as u8, + ((raw_addr >> 8) & 0xFF) as u8, + ((raw_addr >> 0) & 0xFF) as u8, + ); + + Ok(DnsRecord::A { domain, addr, ttl }) + } + QueryType::AAAA => { + let raw_addr1 = buffer.read_u32()?; + let raw_addr2 = buffer.read_u32()?; + let raw_addr3 = buffer.read_u32()?; + let raw_addr4 = buffer.read_u32()?; + let addr = Ipv6Addr::new( + ((raw_addr1 >> 16) & 0xFFFF) as u16, + ((raw_addr1 >> 0) & 0xFFFF) as u16, + ((raw_addr2 >> 16) & 0xFFFF) as u16, + ((raw_addr2 >> 0) & 0xFFFF) as u16, + ((raw_addr3 >> 16) & 0xFFFF) as u16, + ((raw_addr3 >> 0) & 0xFFFF) as u16, + ((raw_addr4 >> 16) & 0xFFFF) as u16, + ((raw_addr4 >> 0) & 0xFFFF) as u16, + ); + + Ok(DnsRecord::AAAA { domain, addr, ttl }) + } + QueryType::NS => { + let mut ns = String::new(); + buffer.read_qname(&mut ns)?; + + Ok(DnsRecord::NS { + domain, + host: ns, + ttl, + }) + } + QueryType::CNAME => { + let mut cname = String::new(); + buffer.read_qname(&mut cname)?; + + Ok(DnsRecord::CNAME { + domain, + host: cname, + ttl, + }) + } + QueryType::MX => { + let priority = buffer.read_u16()?; + let mut mx = String::new(); + buffer.read_qname(&mut mx)?; + + Ok(DnsRecord::MX { + domain, + priority: priority, + host: mx, + ttl, + }) + } + QueryType::UNKNOWN(_) => { + buffer.step(data_len as usize)?; + + Ok(DnsRecord::UNKNOWN { + domain, + qtype: qtype_num, + data_len: data_len, + ttl, + }) + } + } + } + + /// Write Dns record (Answer Section - Authority Section - Additional Section) + pub fn write(&self, buffer: &mut BytePacketBuffer) -> Result { + let start_pos = buffer.pos(); + + match *self { + DnsRecord::A { + ref domain, + ref addr, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::A.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + buffer.write_u16(4)?; + + let octets = addr.octets(); + buffer.write_u8(octets[0])?; + buffer.write_u8(octets[1])?; + buffer.write_u8(octets[2])?; + buffer.write_u8(octets[3])?; + } + DnsRecord::NS { + ref domain, + ref host, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::NS.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::CNAME { + ref domain, + ref host, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::CNAME.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::MX { + ref domain, + priority, + ref host, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::MX.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + + let pos = buffer.pos(); + buffer.write_u16(0)?; + + buffer.write_u16(priority)?; + buffer.write_qname(host)?; + + let size = buffer.pos() - (pos + 2); + buffer.set_u16(pos, size as u16)?; + } + DnsRecord::AAAA { + ref domain, + ref addr, + ttl, + } => { + buffer.write_qname(domain)?; + buffer.write_u16(QueryType::AAAA.to_num())?; + buffer.write_u16(1)?; + buffer.write_u32(ttl)?; + buffer.write_u16(16)?; + + for octet in &addr.segments() { + buffer.write_u16(*octet)?; + } + } + DnsRecord::UNKNOWN { .. } => { + println!("Skipping record: {:?}", self); + } + } + + Ok(buffer.pos() - start_pos) + } +} + +#[cfg(test)] +mod tests { + use super::*; + + #[test] + fn test_read_a_record() { + let mut buf = BytePacketBuffer::new(); + + buf.buf[0] = 6; + buf.buf[1..7].copy_from_slice(b"google"); + buf.buf[7] = 3; + buf.buf[8..11].copy_from_slice(b"com"); + buf.buf[11] = 0; + buf.buf[12] = 0x00; + buf.buf[13] = 0x01; // qtype A + buf.buf[14] = 0x00; + buf.buf[15] = 0x01; // class IN + buf.buf[16] = 0x00; + buf.buf[17] = 0x00; + buf.buf[18] = 0x01; + buf.buf[19] = 0x25; // ttl = 293 + buf.buf[20] = 0x00; + buf.buf[21] = 0x04; // data_len = 4 (IPv4 length) + buf.buf[22] = 0xd8; // 216.58.211.142 + buf.buf[23] = 0x3a; + buf.buf[24] = 0xd3; + buf.buf[25] = 0x8e; + + let record = DnsRecord::read(&mut buf).expect("Failed to read A record"); + + match record { + DnsRecord::A { domain, addr, ttl } => { + assert_eq!(domain, "google.com"); + assert_eq!(addr, Ipv4Addr::new(216, 58, 211, 142)); + assert_eq!(ttl, 293); + } + _ => panic!("Record is not of type A"), + } + } + + #[test] + fn test_write_a_record() { + let record = DnsRecord::A { + domain: "google.com".to_string(), + addr: Ipv4Addr::new(216, 58, 211, 142), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write A record"); + + assert_eq!(buf.buf[0], 6); + assert_eq!(&buf.buf[1..7], b"google"); + assert_eq!(buf.buf[7], 3); + assert_eq!(&buf.buf[8..11], b"com"); + assert_eq!(buf.buf[11], 0); + assert_eq!(buf.buf[12], 0x00); + assert_eq!(buf.buf[13], 0x01); // qtype A + assert_eq!(buf.buf[14], 0x00); + assert_eq!(buf.buf[15], 0x01); // class IN + assert_eq!(buf.buf[16], 0x00); + assert_eq!(buf.buf[17], 0x00); + assert_eq!(buf.buf[18], 0x01); + assert_eq!(buf.buf[19], 0x25); // ttl = 293 + assert_eq!(buf.buf[20], 0x00); + assert_eq!(buf.buf[21], 0x04); // data_len = 4 + assert_eq!(&buf.buf[22..26], &[216, 58, 211, 142]); + } + + #[test] + fn test_a_record() { + let record = DnsRecord::A { + domain: "google.com".to_string(), + addr: Ipv4Addr::new(216, 58, 211, 142), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write A record"); + assert!(buf.seek(0).is_ok()); + let read_record = DnsRecord::read(&mut buf).expect("Failed to read A record"); + + assert_eq!(record, read_record); + } + + #[test] + fn test_read_aaaa_record() { + let mut buf = BytePacketBuffer::new(); + + buf.buf[0] = 6; + buf.buf[1..7].copy_from_slice(b"google"); + buf.buf[7] = 3; + buf.buf[8..11].copy_from_slice(b"com"); + buf.buf[11] = 0; + buf.buf[12] = 0x00; + buf.buf[13] = 0x1C; // qtype AAAA + buf.buf[14] = 0x00; + buf.buf[15] = 0x01; // class IN + buf.buf[16] = 0x00; + buf.buf[17] = 0x00; + buf.buf[18] = 0x01; + buf.buf[19] = 0x25; // ttl = 293 + buf.buf[20] = 0x00; + buf.buf[21] = 0x10; // data_len = 16 (IPv6 length) + buf.buf[22..38].copy_from_slice(&[ + 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + ]); + + let record = DnsRecord::read(&mut buf).expect("Failed to read AAAA record"); + + match record { + DnsRecord::AAAA { domain, addr, ttl } => { + assert_eq!(domain, "google.com"); + assert_eq!( + addr, + Ipv6Addr::new(0x2001, 0xdb8, 0x0010, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001,) + ); + assert_eq!(ttl, 293); + } + _ => panic!("Record is not of type AAAA"), + } + } + + #[test] + fn test_write_aaaa_record() { + let record = DnsRecord::AAAA { + domain: "google.com".to_string(), + addr: Ipv6Addr::new( + 0x2001, 0xdb8, 0x0010, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, + ), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write AAAA record"); + + assert_eq!(buf.buf[0], 6); + assert_eq!(&buf.buf[1..7], b"google"); + assert_eq!(buf.buf[7], 3); + assert_eq!(&buf.buf[8..11], b"com"); + assert_eq!(buf.buf[11], 0); + assert_eq!(buf.buf[12], 0x00); + assert_eq!(buf.buf[13], 0x1C); // qtype AAAA + assert_eq!(buf.buf[14], 0x00); + assert_eq!(buf.buf[15], 0x01); // class IN + assert_eq!(buf.buf[16], 0x00); + assert_eq!(buf.buf[17], 0x00); + assert_eq!(buf.buf[18], 0x01); + assert_eq!(buf.buf[19], 0x25); // ttl = 293 + assert_eq!(buf.buf[20], 0x00); + assert_eq!(buf.buf[21], 0x10); // data_len = 16 + assert_eq!( + &buf.buf[22..38], + &[ + 0x20, 0x01, 0x0d, 0xb8, 0x00, 0x10, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, 0x00, + 0x00, 0x01, + ] + ); + } + + #[test] + fn test_aaaa_record() { + let record = DnsRecord::AAAA { + domain: "google.com".to_string(), + addr: Ipv6Addr::new( + 0x2001, 0xdb8, 0x0010, 0x0000, 0x0000, 0x0000, 0x0000, 0x0001, + ), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write AAAA record"); + assert!(buf.seek(0).is_ok()); + let read_record = DnsRecord::read(&mut buf).expect("Failed to read AAAA record"); + + assert_eq!(record, read_record); + } + + #[test] + fn test_read_cname_record() { + let mut buf = BytePacketBuffer::new(); + + buf.buf[0] = 6; + buf.buf[1..7].copy_from_slice(b"google"); + buf.buf[7] = 3; + buf.buf[8..11].copy_from_slice(b"com"); + buf.buf[11] = 0; + buf.buf[12] = 0x00; + buf.buf[13] = 0x05; // qtype CNAME + buf.buf[14] = 0x00; + buf.buf[15] = 0x01; // class IN + buf.buf[16] = 0x00; + buf.buf[17] = 0x00; + buf.buf[18] = 0x01; + buf.buf[19] = 0x25; // ttl = 293 + buf.buf[20] = 0x00; + buf.buf[21] = 0x05; // data_len = 3 + buf.buf[22] = 0x03; + buf.buf[23..26].copy_from_slice(b"www"); + + let record = DnsRecord::read(&mut buf).expect("Failed to read CNAME record"); + + match record { + DnsRecord::CNAME { domain, host, ttl } => { + assert_eq!(domain, "google.com"); + assert_eq!(host, "www"); + assert_eq!(ttl, 293); + } + _ => panic!("Record is not of type CNAME"), + } + } + + #[test] + fn test_write_cname_record() { + let record = DnsRecord::CNAME { + domain: "google.com".to_string(), + host: "www".to_string(), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record + .write(&mut buf) + .expect("Failed to write CNAME record"); + + assert_eq!(buf.buf[0], 6); + assert_eq!(&buf.buf[1..7], b"google"); + assert_eq!(buf.buf[7], 3); + assert_eq!(&buf.buf[8..11], b"com"); + assert_eq!(buf.buf[11], 0); + assert_eq!(buf.buf[12], 0x00); + assert_eq!(buf.buf[13], 0x05); // qtype CNAME + assert_eq!(buf.buf[14], 0x00); + assert_eq!(buf.buf[15], 0x01); // class IN + assert_eq!(buf.buf[16], 0x00); + assert_eq!(buf.buf[17], 0x00); + assert_eq!(buf.buf[18], 0x01); + assert_eq!(buf.buf[19], 0x25); // ttl = 293 + assert_eq!(buf.buf[20], 0x00); + assert_eq!(buf.buf[21], 0x05); // data_len = 3 + assert_eq!(buf.buf[22], 3); + assert_eq!(&buf.buf[23..26], b"www"); + } + + #[test] + fn test_cname_record() { + let record = DnsRecord::CNAME { + domain: "google.com".to_string(), + host: "www".to_string(), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record + .write(&mut buf) + .expect("Failed to write CNAME record"); + assert!(buf.seek(0).is_ok()); + let read_record = DnsRecord::read(&mut buf).expect("Failed to read CNAME record"); + + assert_eq!(record, read_record); + } + + #[test] + fn test_read_mx_record() { + let mut buf = BytePacketBuffer::new(); + + buf.buf[0] = 6; + buf.buf[1..7].copy_from_slice(b"google"); + buf.buf[7] = 3; + buf.buf[8..11].copy_from_slice(b"com"); + buf.buf[11] = 0; + buf.buf[12] = 0x00; + buf.buf[13] = 0x0F; // qtype MX + buf.buf[14] = 0x00; + buf.buf[15] = 0x01; // class IN + buf.buf[16] = 0x00; + buf.buf[17] = 0x00; + buf.buf[18] = 0x01; + buf.buf[19] = 0x25; // ttl = 293 + buf.buf[20] = 0x00; + buf.buf[21] = 0x09; // data_len = 9 + buf.buf[22] = 0x00; + buf.buf[23] = 0x0A; // priority = 10 + buf.buf[24] = 0x03; + buf.buf[25..28].copy_from_slice(b"mx1"); + buf.buf[28] = 3; + buf.buf[29..32].copy_from_slice(b"com"); + buf.buf[32] = 0; + + let record = DnsRecord::read(&mut buf).expect("Failed to read MX record"); + + match record { + DnsRecord::MX { + domain, + priority, + host, + ttl, + } => { + assert_eq!(domain, "google.com"); + assert_eq!(priority, 10); + assert_eq!(host, "mx1.com"); + assert_eq!(ttl, 293); + } + _ => panic!("Record is not of type MX"), + } + } + + #[test] + fn test_write_mx_record() { + let record = DnsRecord::MX { + domain: "google.com".to_string(), + priority: 10, + host: "mx1.com".to_string(), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write MX record"); + + assert_eq!(buf.buf[0], 6); + assert_eq!(&buf.buf[1..7], b"google"); + assert_eq!(buf.buf[7], 3); + assert_eq!(&buf.buf[8..11], b"com"); + assert_eq!(buf.buf[11], 0); + assert_eq!(buf.buf[12], 0x00); + assert_eq!(buf.buf[13], 0x0F); // qtype MX + assert_eq!(buf.buf[14], 0x00); + assert_eq!(buf.buf[15], 0x01); // class IN + assert_eq!(buf.buf[16], 0x00); + assert_eq!(buf.buf[17], 0x00); + assert_eq!(buf.buf[18], 0x01); + assert_eq!(buf.buf[19], 0x25); // ttl = 293 + assert_eq!(buf.buf[20], 0x00); + assert_eq!(buf.buf[21], 0x0b); // data_len = 9 + assert_eq!(buf.buf[22], 0x00); + assert_eq!(buf.buf[23], 0x0A); // priority = 10 + assert_eq!(buf.buf[24], 0x03); + assert_eq!(&buf.buf[25..28], b"mx1"); + assert_eq!(buf.buf[28], 3); + assert_eq!(&buf.buf[29..32], b"com"); + assert_eq!(buf.buf[32], 0); + } + + #[test] + fn test_mx_record() { + let record = DnsRecord::MX { + domain: "google.com".to_string(), + priority: 10, + host: "mx1.com".to_string(), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write MX record"); + assert!(buf.seek(0).is_ok()); + let read_record = DnsRecord::read(&mut buf).expect("Failed to read MX record"); + + assert_eq!(record, read_record); + } + + #[test] + fn test_read_ns_record() { + let mut buf = BytePacketBuffer::new(); + + buf.buf[0] = 6; + buf.buf[1..7].copy_from_slice(b"google"); + buf.buf[7] = 3; + buf.buf[8..11].copy_from_slice(b"com"); + buf.buf[11] = 0; + buf.buf[12] = 0x00; + buf.buf[13] = 0x02; // qtype NS + buf.buf[14] = 0x00; + buf.buf[15] = 0x01; // class IN + buf.buf[16] = 0x00; + buf.buf[17] = 0x00; + buf.buf[18] = 0x01; + buf.buf[19] = 0x25; // ttl = 293 + buf.buf[20] = 0x00; + buf.buf[21] = 0x09; // data_len = 9 + buf.buf[22] = 0x03; + buf.buf[23..26].copy_from_slice(b"ns1"); + buf.buf[26] = 3; + buf.buf[27..30].copy_from_slice(b"com"); + buf.buf[30] = 0; + + let record = DnsRecord::read(&mut buf).expect("Failed to read NS record"); + + match record { + DnsRecord::NS { domain, host, ttl } => { + assert_eq!(domain, "google.com"); + assert_eq!(host, "ns1.com"); + assert_eq!(ttl, 293); + } + _ => panic!("Record is not of type NS"), + } + } + + #[test] + fn test_write_ns_record() { + let record = DnsRecord::NS { + domain: "google.com".to_string(), + host: "ns1.com".to_string(), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write NS record"); + + assert_eq!(buf.buf[0], 6); + assert_eq!(&buf.buf[1..7], b"google"); + assert_eq!(buf.buf[7], 3); + assert_eq!(&buf.buf[8..11], b"com"); + assert_eq!(buf.buf[11], 0); + assert_eq!(buf.buf[12], 0x00); + assert_eq!(buf.buf[13], 0x02); // qtype NS + assert_eq!(buf.buf[14], 0x00); + assert_eq!(buf.buf[15], 0x01); // class IN + assert_eq!(buf.buf[16], 0x00); + assert_eq!(buf.buf[17], 0x00); + assert_eq!(buf.buf[18], 0x01); + assert_eq!(buf.buf[19], 0x25); // ttl = 293 + assert_eq!(buf.buf[20], 0x00); + assert_eq!(buf.buf[21], 0x09); // data_len = 9 + assert_eq!(buf.buf[22], 0x03); + assert_eq!(&buf.buf[23..26], b"ns1"); + assert_eq!(buf.buf[26], 3); + assert_eq!(&buf.buf[27..30], b"com"); + assert_eq!(buf.buf[30], 0); + } + + #[test] + fn test_ns_record() { + let record = DnsRecord::NS { + domain: "google.com".to_string(), + host: "ns1.com".to_string(), + ttl: 293, + }; + + let mut buf = BytePacketBuffer::new(); + record.write(&mut buf).expect("Failed to write NS record"); + assert!(buf.seek(0).is_ok()); + let read_record = DnsRecord::read(&mut buf).expect("Failed to read NS record"); + + assert_eq!(record, read_record); + } +} diff --git a/src/lib.rs b/src/lib.rs new file mode 100644 index 0000000..0ffa0d9 --- /dev/null +++ b/src/lib.rs @@ -0,0 +1,139 @@ +use std::{ + io::{self}, + net::{Ipv4Addr, UdpSocket}, +}; +mod byte_packet_buffer; +mod dns_header; +mod dns_packet; +mod dns_question; +mod dns_record; +mod query_types; +use { + byte_packet_buffer::BytePacketBuffer, dns_header::ResultCode, dns_packet::DnsPacket, + dns_question::DnsQuestion, query_types::QueryType, +}; + +fn recursive_lookup( + qname: &str, + qtype: QueryType, + root_ip: &Option, +) -> Result { + // For now we're always starting with *a.root-servers.net* if no forwarding ip specified. + let mut ns = match root_ip { + Some(ip) => ip, + None => "198.41.0.4", + } + .parse::() + .unwrap(); + + loop { + println!("attempting lookup of {:?} {} with ns {}", qtype, qname, ns); + + let ns_copy = ns; + + let server = (ns_copy, 53); + let response = lookup(qname, qtype, server)?; + + if !response.answers.is_empty() && response.header.rescode == ResultCode::NOERROR { + return Ok(response); + } + + if response.header.rescode == ResultCode::NXDOMAIN { + return Ok(response); + } + + if let Some(new_ns) = response.get_resolved_ns(qname) { + ns = new_ns; + + continue; + } + + let new_ns_name = match response.get_unresolved_ns(qname) { + Some(x) => x, + None => return Ok(response), + }; + + let recursive_response = recursive_lookup(&new_ns_name, QueryType::A, root_ip)?; + + if let Some(new_ns) = recursive_response.get_random_a() { + ns = new_ns; + } else { + return Ok(response); + } + } +} + +fn lookup(qname: &str, qtype: QueryType, server: (Ipv4Addr, u16)) -> Result { + let socket = UdpSocket::bind(("0.0.0.0", 49152))?; + + let mut packet = DnsPacket::new(); + + packet.header.id = 6666; + packet.header.questions = 1; + packet.header.recursion_desired = true; + packet + .questions + .push(DnsQuestion::new(qname.to_string(), qtype)); + + let mut req_buffer = BytePacketBuffer::new(); + packet.write(&mut req_buffer)?; + socket.send_to(&req_buffer.get_range(0, req_buffer.pos() + 1)?, server)?; + + let mut res_buffer = BytePacketBuffer::new(); + socket.recv_from(&mut res_buffer.buf)?; + + DnsPacket::from_buffer(&mut res_buffer) +} + +/// Handle query +/// function handles an incoming DNS query from a UDP socket +/// and formulates an appropriate DNS response, +/// either answering directly or using recursive resolution to obtain the necessary data +pub fn handle_query(socket: &UdpSocket, root_ip: &Option) -> Result<(), io::Error> { + let mut req_buffer = BytePacketBuffer::new(); + + let (_, src) = socket.recv_from(&mut req_buffer.buf)?; + + let mut request = DnsPacket::from_buffer(&mut req_buffer)?; + + let mut packet = DnsPacket::new(); + packet.header.id = request.header.id; + packet.header.recursion_desired = true; + packet.header.recursion_available = true; + packet.header.response = true; + + if let Some(question) = request.questions.pop() { + println!("Received query: {:?}", question); + if let Ok(result) = recursive_lookup(&question.name, question.qtype, root_ip) { + packet.questions.push(question.clone()); + packet.header.rescode = result.header.rescode; + + for rec in result.answers { + println!("Answer: {:?}", rec); + packet.answers.push(rec); + } + for rec in result.authorities { + println!("Authority: {:?}", rec); + packet.authorities.push(rec); + } + for rec in result.resources { + println!("Resource: {:?}", rec); + packet.resources.push(rec); + } + } else { + packet.header.rescode = ResultCode::SERVFAIL; + } + } else { + packet.header.rescode = ResultCode::FORMERR; + } + + let mut res_buffer = BytePacketBuffer::new(); + packet.write(&mut res_buffer)?; + + let len = res_buffer.pos(); + let data = res_buffer.get_range(0, len)?; + + socket.send_to(data, src)?; + + Ok(()) +} diff --git a/src/main.rs b/src/main.rs new file mode 100644 index 0000000..c5681a9 --- /dev/null +++ b/src/main.rs @@ -0,0 +1,31 @@ +use clap::Parser; +use std::{io, net::UdpSocket}; + +#[derive(Parser)] +#[command(name = "DnsServerApp")] +#[command(version = "1.0")] +#[command(about = "A simple DNS server application", long_about = None)] +struct Args { + /// Port to bind the UDP socket + #[arg(short, long, default_value_t = 2053)] + port: u16, + /// forward replies to specified dns server + #[arg(short, long, default_value = None)] + forward_ip: Option, +} + +fn main() -> Result<(), io::Error> { + let args = Args::parse(); + // Bind an UDP socket on port + let socket = UdpSocket::bind(("0.0.0.0", args.port))?; + println!("Server listening on port {}", args.port); + + // For now, queries are handled sequentially, so an infinite loop for servicing + // requests is initiated. + loop { + match dnsserver_nabil::handle_query(&socket, &args.forward_ip) { + Ok(_) => {} + Err(e) => eprintln!("An error occurred: {}", e), + } + } +} diff --git a/src/query_types.rs b/src/query_types.rs new file mode 100644 index 0000000..bb0f36a --- /dev/null +++ b/src/query_types.rs @@ -0,0 +1,33 @@ +#[derive(PartialEq, Eq, Debug, Clone, Copy)] +pub enum QueryType { + UNKNOWN(u16), + A, // 1 + NS, // 2 + CNAME, // 5 + MX, // 15 + AAAA, // 28 +} + +impl QueryType { + pub fn to_num(&self) -> u16 { + match *self { + QueryType::UNKNOWN(x) => x, + QueryType::A => 1, + QueryType::NS => 2, + QueryType::CNAME => 5, + QueryType::MX => 15, + QueryType::AAAA => 28, + } + } + + pub fn from_num(num: u16) -> QueryType { + match num { + 1 => QueryType::A, + 2 => QueryType::NS, + 5 => QueryType::CNAME, + 15 => QueryType::MX, + 28 => QueryType::AAAA, + _ => QueryType::UNKNOWN(num), + } + } +}