diff --git a/kernel/networking/drivers/net_driver.hpp b/kernel/networking/drivers/net_driver.hpp index df150e9f..df69b3d0 100644 --- a/kernel/networking/drivers/net_driver.hpp +++ b/kernel/networking/drivers/net_driver.hpp @@ -1,28 +1,21 @@ #pragma once #include "types.h" -#include "std/string.h" -#include "ui/graphic_types.h" -#include "net/network_types.h" +#include "net/network_types.h" class NetDriver { public: - NetDriver(){} + NetDriver() = default; virtual bool init() = 0; virtual sizedptr allocate_packet(size_t size) = 0; - virtual sizedptr handle_receive_packet(void* buffer) = 0; - virtual void handle_sent_packet() = 0; - virtual void enable_verbose() = 0; - virtual void send_packet(sizedptr packet) = 0; - - virtual void get_mac(network_connection_ctx *context) = 0; + virtual void get_mac(net_l2l3_endpoint *context) = 0; virtual ~NetDriver() = default; uint16_t header_size; -}; \ No newline at end of file +}; diff --git a/kernel/networking/drivers/virtio_net_pci/virtio_net_pci.cpp b/kernel/networking/drivers/virtio_net_pci/virtio_net_pci.cpp index 6efaee22..76ff517c 100644 --- a/kernel/networking/drivers/virtio_net_pci/virtio_net_pci.cpp +++ b/kernel/networking/drivers/virtio_net_pci/virtio_net_pci.cpp @@ -106,7 +106,7 @@ bool VirtioNetDriver::init(){ } -void VirtioNetDriver::get_mac(network_connection_ctx *context){ +void VirtioNetDriver::get_mac(net_l2l3_endpoint *context){ virtio_net_config* net_config = (virtio_net_config*)vnp_net_dev.device_cfg; kprintfv("[VIRTIO_NET] %x:%x:%x:%x:%x:%x", net_config->mac[0], net_config->mac[1], net_config->mac[2], net_config->mac[3], net_config->mac[4], net_config->mac[5]); diff --git a/kernel/networking/drivers/virtio_net_pci/virtio_net_pci.hpp b/kernel/networking/drivers/virtio_net_pci/virtio_net_pci.hpp index 22f49945..86550911 100644 --- a/kernel/networking/drivers/virtio_net_pci/virtio_net_pci.hpp +++ b/kernel/networking/drivers/virtio_net_pci/virtio_net_pci.hpp @@ -15,13 +15,14 @@ class VirtioNetDriver : public NetDriver { sizedptr allocate_packet(size_t size) override; sizedptr handle_receive_packet(void* buffer) override; + void handle_sent_packet() override; void enable_verbose() override; void send_packet(sizedptr packet) override; - void get_mac(network_connection_ctx *context) override; + void get_mac(net_l2l3_endpoint *context) override; ~VirtioNetDriver() = default; diff --git a/kernel/networking/network.cpp b/kernel/networking/network.cpp index fe4a170d..a68fd2b0 100644 --- a/kernel/networking/network.cpp +++ b/kernel/networking/network.cpp @@ -1,52 +1,47 @@ #include "network.h" #include "network_dispatch.hpp" -#include "std/allocator.hpp" #include "process/scheduler.h" -NetworkDispatch *dispatch; +static NetworkDispatch *dispatch = nullptr; -bool network_init(){ +extern "C" bool network_init() { dispatch = new NetworkDispatch(); - return dispatch->init(); + return dispatch && dispatch->init(); } -void network_handle_download_interrupt(){ - return dispatch->handle_download_interrupt(); +extern "C" void network_handle_download_interrupt() { + if (dispatch) dispatch->handle_download_interrupt(); } -void network_handle_upload_interrupt(){ - return dispatch->handle_upload_interrupt(); +extern "C" void network_handle_upload_interrupt() { + if (dispatch) dispatch->handle_upload_interrupt(); } - -bool network_bind_port(uint16_t port, uint16_t process){ - return dispatch->bind_port(port, process); -} - -bool network_unbind_port(uint16_t port, uint16_t process){ - return dispatch->unbind_port(port, process); +extern "C" void network_net_task_entry() { + if (dispatch) dispatch->net_task(); } -bool network_bind_port_current(uint16_t port){ - return dispatch->bind_port(port, get_current_proc_pid()); +extern "C" int net_tx_frame(uintptr_t frame_ptr, uint32_t frame_len) { + if (!dispatch || !frame_ptr || !frame_len) return -1; + return dispatch->enqueue_frame({frame_ptr, frame_len}) ? 0 : -1; } -bool network_unbind_port_current(uint16_t port){ - return dispatch->unbind_port(port, get_current_proc_pid()); +extern "C" int net_rx_frame(sizedptr *out_frame) { + extern uint16_t get_current_proc_pid(); + if (!dispatch || !out_frame) return -1; + int sz = dispatch->dequeue_packet_for(get_current_proc_pid(), out_frame) ? (int)out_frame->size : 0; + return sz; } -void network_send_packet(NetProtocol protocol, uint16_t port, network_connection_ctx *destination, void* payload, uint16_t payload_len){ - return dispatch->send_packet(protocol, port, destination, payload, payload_len); +extern "C" const net_l2l3_endpoint* network_get_local_endpoint() { + static net_l2l3_endpoint dummy = {0}; + return dispatch ? &dispatch->get_local_ep() : &dummy; } -bool network_read_packet(sizedptr *packet, uint16_t process){ - return dispatch->read_packet(packet, process); +extern "C" void network_net_set_pid(uint16_t pid) { + if (dispatch) dispatch->set_net_pid(pid); } -bool network_read_packet_current(sizedptr *packet){ - return dispatch->read_packet(packet, get_current_proc_pid()); +extern "C" uint16_t network_net_get_pid() { + return dispatch ? dispatch->get_net_pid() : UINT16_MAX; } - -network_connection_ctx* network_get_context(){ - return dispatch->get_context(); -} \ No newline at end of file diff --git a/kernel/networking/network.h b/kernel/networking/network.h index fbc2473b..80b789a2 100644 --- a/kernel/networking/network.h +++ b/kernel/networking/network.h @@ -8,25 +8,23 @@ extern "C" { #include "net/network_types.h" #define NET_IRQ 32 - -//TODO: review this number +//TODO: consider using the system MTU here #define MAX_PACKET_SIZE 0x1000 +void network_net_set_pid(uint16_t pid); +uint16_t network_net_get_pid(); + bool network_init(); void network_handle_download_interrupt(); void network_handle_upload_interrupt(); -bool network_bind_port(uint16_t port, uint16_t process); -bool network_unbind_port(uint16_t port, uint16_t process); -void network_send_packet(NetProtocol protocol, uint16_t port, network_connection_ctx *destination, void* payload, uint16_t payload_len); +void network_net_task_entry(); -bool network_bind_port_current(uint16_t port); -bool network_unbind_port_current(uint16_t port); +int net_tx_frame(uintptr_t frame_ptr, uint32_t frame_len); +int net_rx_frame(sizedptr *out_frame); -bool network_read_packet(sizedptr *packet, uint16_t process); -bool network_read_packet_current(sizedptr *packet); - -network_connection_ctx* network_get_context(); +const net_l2l3_endpoint* network_get_local_endpoint(); +void network_update_local_ip(uint32_t ip); #ifdef __cplusplus } -#endif \ No newline at end of file +#endif diff --git a/kernel/networking/network_dispatch.cpp b/kernel/networking/network_dispatch.cpp index 590de3db..901b23e6 100644 --- a/kernel/networking/network_dispatch.cpp +++ b/kernel/networking/network_dispatch.cpp @@ -1,164 +1,165 @@ #include "network_dispatch.hpp" +#include "network.h" #include "drivers/virtio_net_pci/virtio_net_pci.hpp" -#include "net/network_types.h" -#include "console/kio.h" -#include "process/scheduler.h" -#include "net/udp.h" -#include "net/tcp.h" -#include "net/dhcp.h" -#include "net/arp.h" -#include "net/eth.h" -#include "net/ipv4.h" -#include "net/icmp.h" #include "memory/page_allocator.h" +#include "net/link_layer/eth.h" +#include "net/network_types.h" +#include "port_manager.h" #include "std/memfunctions.h" -#include "hw/hw.h" -#include "network.h" -NetworkDispatch::NetworkDispatch(){ - ports = IndexMap(UINT16_MAX); - for (uint16_t i = 0; i < UINT16_MAX; i++) +extern void sleep(uint64_t ms); +extern uintptr_t malloc(uint64_t size); +extern void free(void *ptr, uint64_t size); + +static uint16_t g_net_pid = 0xFFFF; +static uint8_t recv_buffer[MAX_PACKET_SIZE]; + +NetworkDispatch::NetworkDispatch() + : ports(UINT16_MAX + 1), + driver(nullptr), + tx_queue(QUEUE_CAPACITY), + rx_queue(QUEUE_CAPACITY) +{ + for (uint32_t i = 0; i <= UINT16_MAX; ++i) ports[i] = UINT16_MAX; - context = (network_connection_ctx) {0}; + + memset(local_mac.mac, 0, sizeof(local_mac.mac)); } -NetDriver* NetworkDispatch::select_driver(){ - return BOARD_TYPE == 1 ? VirtioNetDriver::try_init() : 0x0; +bool NetworkDispatch::init() +{ + driver = VirtioNetDriver::try_init(); + if (!driver) return false; + driver->get_mac(&local_mac); + return true; } -bool NetworkDispatch::init(){ - if ((driver = select_driver())){ - driver->get_mac(&context); - return true; +void NetworkDispatch::handle_download_interrupt() +{ + if (!driver) return; + + sizedptr raw = driver->handle_receive_packet(recv_buffer); + if (raw.size < sizeof(eth_hdr_t)) { + return; } - return false; + + sizedptr frame{0, raw.size}; + frame.ptr = reinterpret_cast( + kalloc(reinterpret_cast(get_current_heap()), + raw.size, ALIGN_16B, + get_current_privilege(), false)); + if (!frame.ptr) return; + + memcpy(reinterpret_cast(frame.ptr), recv_buffer, raw.size); + + if (!rx_queue.enqueue(frame)) + free_frame(frame); } -bool NetworkDispatch::bind_port(uint16_t port, uint16_t process){ - if (ports[port] != UINT16_MAX) return false; - ports[port] = process; - return true; +void NetworkDispatch::handle_upload_interrupt() +{ + if (driver) + driver->handle_sent_packet(); } -bool NetworkDispatch::unbind_port(uint16_t port, uint16_t process){ - if (ports[port] != process) return false; - ports[port] = UINT16_MAX; +bool NetworkDispatch::enqueue_frame(const sizedptr &frame) +{ + if (frame.size == 0) return false; + + sizedptr pkt = driver->allocate_packet(frame.size); + if (!pkt.ptr) return false; + + void* dst = reinterpret_cast(pkt.ptr + driver->header_size); + memcpy(dst, reinterpret_cast(frame.ptr), frame.size); + + if (!tx_queue.enqueue(pkt)) { + free_frame(pkt); + return false; + } return true; } -void NetworkDispatch::handle_download_interrupt(){ - if (driver){ - void *buffer = kalloc((void*)get_current_heap(), MAX_PACKET_SIZE, ALIGN_16B, get_current_privilege(), false); - sizedptr packet = driver->handle_receive_packet(buffer); - bool need_free = true; - uintptr_t ptr = packet.ptr; - if (ptr){ - eth_hdr_t *eth = (eth_hdr_t*)ptr; - uint16_t ethtype = eth_parse_packet_type(ptr); - ptr += sizeof(eth_hdr_t); - if (ethtype == 0x806){ - arp_hdr_t *arp = (arp_hdr_t*)ptr; - if (arp_should_handle(arp, get_context()->ip)){ - kprintf("Received an ARP request"); - bool req = 0; - network_connection_ctx conn; - arp_populate_response(&conn, arp); - send_packet(ARP, 0, &conn, &req, 1); - } - //TODO: Should also look for responses to our own queries - } else if (ethtype == 0x800){//IPV4 - ipv4_hdr_t *ipv4 = (ipv4_hdr_t*)ptr; - uint8_t protocol = ipv4_get_protocol(ptr); - ptr += sizeof(ipv4_hdr_t); - if (protocol == 0x11 || protocol == 0x06){ - uint16_t port = udp_parse_packet(ptr); - if (ports[port] != UINT16_MAX){ - process_t *proc = get_proc_by_pid(ports[port]); - if (!proc) - unbind_port(port, ports[port]); - else { - packet_buffer_t* buf = &proc->packet_buffer; - uint32_t next_index = (buf->write_index + 1) % PACKET_BUFFER_CAPACITY; - - buf->entries[buf->write_index] = packet; - buf->write_index = next_index; - - need_free = false; - - if (buf->write_index == buf->read_index) - buf->read_index = (buf->read_index + 1) % PACKET_BUFFER_CAPACITY; - } - } - } else if (protocol == 0x1) { - icmp_data data = (icmp_data){ - .response = true - }; - network_connection_ctx conn; - icmp_packet *icmp = (icmp_packet*)ptr; - data.seq = icmp_get_sequence(icmp); - data.id = icmp_get_id(icmp); - icmp_copy_payload(&data.payload, icmp); - ipv4_populate_response(&conn, eth, ipv4); - send_packet(ICMP, 0, &conn, &data, sizeof(icmp_data)); - } - } +void NetworkDispatch::net_task() +{ + for (;;) { + bool did_work = false; + sizedptr pkt; + + //rx + if (!rx_queue.is_empty() && rx_queue.dequeue(pkt)) { + did_work = true; + eth_input(pkt.ptr, pkt.size); + free_frame(pkt); } - if (need_free){ - kfree(buffer, MAX_PACKET_SIZE); + + //tx + if (!tx_queue.is_empty() && tx_queue.dequeue(pkt)) { + did_work = true; + driver->send_packet(pkt); } + + if (!did_work) + sleep(10); } } -void NetworkDispatch::handle_upload_interrupt(){ - driver->handle_sent_packet(); -} +bool NetworkDispatch::dequeue_packet_for(uint16_t pid, sizedptr *out) +{ + process_t *proc = get_proc_by_pid(pid); + if (!proc || !out) return false; -bool NetworkDispatch::read_packet(sizedptr *Packet, uint16_t process){ - process_t *proc = get_proc_by_pid(process); - if (proc->packet_buffer.read_index == proc->packet_buffer.write_index) return false; + auto &buf = proc->packet_buffer; + if (buf.read_index == buf.write_index) return false; - sizedptr original = proc->packet_buffer.entries[proc->packet_buffer.read_index]; - - uintptr_t copy = (uintptr_t)kalloc((void*)get_current_heap(), original.size, ALIGN_16B, get_current_privilege(), false); - memcpy((void*)copy,(void*)original.ptr,original.size); - Packet->ptr = copy; - Packet->size = original.size; - free_sized(original); - proc->packet_buffer.read_index = (proc->packet_buffer.read_index + 1) % PACKET_BUFFER_CAPACITY; + sizedptr stored = buf.entries[buf.read_index]; + buf.read_index = (buf.read_index + 1) % PACKET_BUFFER_CAPACITY; + + void *dst = kalloc(reinterpret_cast(get_current_heap()), + stored.size, ALIGN_16B, + get_current_privilege(), false); + if (!dst) return false; + + memcpy(dst, reinterpret_cast(stored.ptr), stored.size); + out->ptr = reinterpret_cast(dst); + out->size = stored.size; + + free(reinterpret_cast(stored.ptr), stored.size); return true; } -void NetworkDispatch::send_packet(NetProtocol protocol, uint16_t port, network_connection_ctx *destination, void* payload, uint16_t payload_len){ - sizedptr packet_buffer; - switch (protocol) { - case UDP: - packet_buffer = driver->allocate_packet(sizeof(eth_hdr_t) + sizeof(ipv4_hdr_t) + sizeof(udp_hdr_t) + payload_len); - context.port = port; - create_udp_packet(packet_buffer.ptr + driver->header_size, context, *destination, (sizedptr){(uintptr_t)payload, payload_len}); - break; - case DHCP: - packet_buffer = driver->allocate_packet(DHCP_SIZE); - create_dhcp_packet(packet_buffer.ptr + driver->header_size, (dhcp_request*)payload); - break; - case ARP: - packet_buffer = driver->allocate_packet(sizeof(eth_hdr_t) + sizeof(arp_hdr_t)); - create_arp_packet(packet_buffer.ptr + driver->header_size, context.mac, context.ip, destination->mac, destination->ip, *(bool*)payload); - break; - case ICMP: - packet_buffer = driver->allocate_packet(sizeof(eth_hdr_t) + sizeof(ipv4_hdr_t) + sizeof(icmp_packet)); - create_icmp_packet(packet_buffer.ptr + driver->header_size, context, *destination, (icmp_data*)payload); - break; - case TCP: - tcp_data *data = (tcp_data*)payload; - packet_buffer = driver->allocate_packet(sizeof(eth_hdr_t) + sizeof(ipv4_hdr_t) + sizeof(tcp_hdr_t) + data->options.size + data->payload.size); - context.port = port; - create_tcp_packet(packet_buffer.ptr + driver->header_size, context, *destination, (sizedptr){(uintptr_t)data, sizeof(tcp_data)}); - break; - } - if (driver) - driver->send_packet(packet_buffer); +static sizedptr make_user_copy(const sizedptr &src) +{ + sizedptr out{0, 0}; + uintptr_t mem = malloc(src.size); + if (!mem) return out; + + memcpy(reinterpret_cast(mem), + reinterpret_cast(src.ptr), + src.size); + + out.ptr = mem; + out.size = src.size; + return out; +} + +sizedptr NetworkDispatch::make_copy(const sizedptr &in) +{ + sizedptr out{0, 0}; + void *dst = kalloc(reinterpret_cast(get_current_heap()), + in.size, ALIGN_16B, + get_current_privilege(), false); + if (!dst) return out; + + memcpy(dst, reinterpret_cast(in.ptr), in.size); + out.ptr = reinterpret_cast(dst); + out.size = in.size; + return out; +} + +void NetworkDispatch::free_frame(const sizedptr &f) +{ + if (f.ptr) free_sized(f); } -network_connection_ctx* NetworkDispatch::get_context(){ - return &context; -} \ No newline at end of file +void NetworkDispatch::set_net_pid(uint16_t pid) { g_net_pid = pid; } +uint16_t NetworkDispatch::get_net_pid() const { return g_net_pid; } diff --git a/kernel/networking/network_dispatch.hpp b/kernel/networking/network_dispatch.hpp index 94cdcda7..4093389c 100644 --- a/kernel/networking/network_dispatch.hpp +++ b/kernel/networking/network_dispatch.hpp @@ -1,30 +1,48 @@ #pragma once - -#include "net/network_types.h" #include "types.h" #include "std/std.hpp" #include "drivers/net_driver.hpp" +#include "data_struct/queue.hpp" +#include "net/network_types.h" +#include "net/internet_layer/ipv4.h" class NetworkDispatch { public: NetworkDispatch(); + bool init(); - bool bind_port(uint16_t port, uint16_t process); - bool unbind_port(uint16_t port, uint16_t process); - void handle_upload_interrupt(); + void handle_download_interrupt(); - //TODO: use sizedptr - void send_packet(NetProtocol protocol, uint16_t port, network_connection_ctx *destination, void* payload, uint16_t payload_len); - bool read_packet(sizedptr *Packet, uint16_t process); + void handle_upload_interrupt(); + bool enqueue_frame(const sizedptr&); + void net_task(); + bool dequeue_packet_for(uint16_t, sizedptr*); - network_connection_ctx* get_context(); + void set_net_pid(uint16_t pid); + uint16_t get_net_pid() const; + + + const net_l2l3_endpoint& get_local_ep() const { + static net_l2l3_endpoint ep; //TODO: locking/thread safe would be good + ep = local_mac; + ep.ip = ipv4_get_cfg()->ip; + return ep; + } + + + NetDriver* driver_ptr() const { return driver; } + uint16_t header_size() const { return driver ? driver->header_size : 0; } private: - IndexMap ports; - NetDriver *driver; + static constexpr size_t QUEUE_CAPACITY = 1024; + + IndexMap ports; //port pid map + NetDriver* driver; + net_l2l3_endpoint local_mac; - NetDriver* select_driver(); + Queue tx_queue; + Queue rx_queue; - sizedptr allocate_packet(size_t size); - network_connection_ctx context; -}; \ No newline at end of file + sizedptr make_copy(const sizedptr&); + void free_frame(const sizedptr&); +}; diff --git a/kernel/networking/port_manager.c b/kernel/networking/port_manager.c new file mode 100644 index 00000000..51ededb7 --- /dev/null +++ b/kernel/networking/port_manager.c @@ -0,0 +1,97 @@ +#include "port_manager.h" +#include "types.h" +#include "networking/port_manager.h" +#include "net/internet_layer/ipv4.h" + +static port_entry_t g_port_table[PROTO_COUNT][MAX_PORTS];//tab proto/port + +static inline bool port_valid(uint16_t p) { + return p > 0 && p < MAX_PORTS; +} +static inline bool proto_valid(protocol_t proto) { + return (uint32_t)proto< PROTO_COUNT; +} + +void port_manager_init() { + for (int pr = 0; pr < PROTO_COUNT; ++pr) { + for (uint32_t p = 0; p < MAX_PORTS; ++p) { + g_port_table[pr][p].used = false; + g_port_table[pr][p].pid = PORT_FREE_OWNER; + g_port_table[pr][p].handler = NULL; + } + } +} + +int port_alloc_ephemeral(protocol_t proto, + uint16_t pid, + port_recv_handler_t handler) +{ + if (!proto_valid(proto)) return -1; + for (uint16_t p = PORT_MIN_EPHEMERAL; p <= PORT_MAX_EPHEMERAL; ++p) { + port_entry_t *e = &g_port_table[proto][p]; + if (!e->used) { + e->used = true; + e->pid = pid; + e->handler = handler; + return (int)p; + } + } + return -1; +} + +bool port_bind_manual(protocol_t proto, + uint16_t port, + uint16_t pid, + port_recv_handler_t handler) +{ + if (!proto_valid(proto) || !port_valid(port)) return false; + port_entry_t *e = &g_port_table[proto][port]; + if (e->used) return false; + e->used = true; + e->pid = pid; + e->handler = handler; + return true; +} + +bool port_unbind(protocol_t proto, + uint16_t port, + uint16_t pid) +{ + if (!proto_valid(proto) || !port_valid(port)) return false; + port_entry_t *e = &g_port_table[proto][port]; + if (!e->used || e->pid != pid) return false; + e->used = false; + e->pid = PORT_FREE_OWNER; + e->handler = NULL; + return true; +} + +void port_unbind_all(uint16_t pid) { + for (int pr = 0; pr < PROTO_COUNT; ++pr) { + for (uint16_t p = 1; p < MAX_PORTS; ++p) { + port_entry_t *e = &g_port_table[pr][p]; + if (e->used && e->pid == pid) { + e->used = false; + e->pid = PORT_FREE_OWNER; + e->handler = NULL; + } + } + } +} + +bool port_is_bound(protocol_t proto, uint16_t port) { + if (!proto_valid(proto) || !port_valid(port)) return false; + return g_port_table[proto][port].used; +} + +uint16_t port_owner_of(protocol_t proto, uint16_t port) { + if (!proto_valid(proto) || !port_valid(port)) return PORT_FREE_OWNER; + return g_port_table[proto][port].pid; +} + +port_recv_handler_t port_get_handler(protocol_t proto, uint16_t port) { + if (!proto_valid(proto) || !port_valid(port)) return NULL; + return g_port_table[proto][port].used + ? g_port_table[proto][port].handler + : NULL; +} diff --git a/kernel/networking/port_manager.h b/kernel/networking/port_manager.h new file mode 100644 index 00000000..b623d909 --- /dev/null +++ b/kernel/networking/port_manager.h @@ -0,0 +1,58 @@ +#pragma once +#include "types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define MAX_PORTS 65536 +#define PORT_MIN_EPHEMERAL 49152 +#define PORT_MAX_EPHEMERAL 65535 +#define PORT_FREE_OWNER 0xFFFF + +typedef enum { + PROTO_UDP = 0, + PROTO_TCP = 1 +} protocol_t; + +#define PROTO_COUNT 2 + +typedef void (*port_recv_handler_t)( + uintptr_t frame_ptr, + uint32_t frame_len, + uint32_t src_ip, + uint16_t src_port, + uint16_t dst_port); + +typedef struct { + uint16_t pid; + port_recv_handler_t handler; + bool used; +} port_entry_t; + +void port_manager_init(); + +int port_alloc_ephemeral(protocol_t proto, + uint16_t pid, + port_recv_handler_t handler); + +bool port_bind_manual(protocol_t proto, + uint16_t port, + uint16_t pid, + port_recv_handler_t handler); + +bool port_unbind(protocol_t proto, + uint16_t port, + uint16_t pid); + +void port_unbind_all(uint16_t pid); + +bool port_is_bound(protocol_t proto, uint16_t port); + +uint16_t port_owner_of(protocol_t proto, uint16_t port); + +port_recv_handler_t port_get_handler(protocol_t proto, uint16_t port); + +#ifdef __cplusplus +} +#endif diff --git a/kernel/networking/processes/net_proc.c b/kernel/networking/processes/net_proc.c index 81122989..8ce94e60 100644 --- a/kernel/networking/processes/net_proc.c +++ b/kernel/networking/processes/net_proc.c @@ -1,150 +1,317 @@ #include "net_proc.h" #include "kernel_processes/kprocess_loader.h" -#include "net/network_types.h" #include "process/scheduler.h" #include "console/kio.h" -#include "net/udp.h" -#include "net/dhcp.h" #include "std/memfunctions.h" +#include "std/string.h" +#include "net/internet_layer/ipv4.h" +#include "net/transport_layer/csocket_udp.h" +#include "net/application_layer/csocket_http_client.h" +#include "net/application_layer/csocket_http_server.h" +#include "net/application_layer/dhcp_daemon.h" +#include "net/network_types.h" +#include "net/link_layer/arp.h" #include "networking/network.h" -#include "syscalls/syscalls.h" -#include "math/math.h" -#include "net/tcp.h" -#include "net/http.h" -#include "net/ipv4.h" -#include "net/eth.h" +#include "net/net.h" -network_connection_ctx server; +extern uintptr_t malloc(uint64_t size); +extern void free(void *ptr, uint64_t size); +extern void sleep(uint64_t ms); -bool find_server(){ - bind_port(7777); - server = (network_connection_ctx){ - .ip = (192 << 24) | (168 << 16) | (1 << 8) | 255, - .mac = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, - .port = 8080, - }; +#define KP(fmt, ...) \ + do { kprintf(fmt, ##__VA_ARGS__); } while (0) - size_t payload_size = 5; - char hw[5] = {'h','e','l','l','o'}; +static uint32_t pick_probe_ip() { + const net_cfg_t *cfg = ipv4_get_cfg(); + if (!cfg || cfg->mode == NET_MODE_DISABLED || cfg->ip == 0) + return 0; + if (cfg->gw) + return cfg->gw; + uint32_t bcast = ipv4_broadcast(cfg->ip, cfg->mask); + if (bcast) + return bcast; + return ipv4_first_host(cfg->ip, cfg->mask); +} - send_packet(UDP, 7777, &server, hw, payload_size); +static int udp_probe_server(uint32_t probe_ip, + uint16_t probe_port, + net_l2l3_endpoint *out_l2, + net_l4_endpoint *out_l4) +{ + const net_l2l3_endpoint *local = network_get_local_endpoint(); + if (!local) + return 0; - sizedptr pack; + socket_handle_t sock = udp_socket_create(0, 0); + if (!sock) + return 0; - while (!read_packet(&pack)); + const char greeting[] = "hello"; + if (socket_sendto_udp(sock, probe_ip, probe_port, greeting, sizeof(greeting)) < 0) { + socket_destroy_udp(sock); + return 0; + } - memcpy((void*)&server.mac, (void*)eth_get_source(pack.ptr), 6); + char recv_buf[64]; + uint32_t waited = 0; + const uint32_t TIMEOUT_MS = 1000; + const uint32_t INTERVAL_MS = 50; + int64_t recvd = 0; + uint32_t resp_ip = 0; + uint16_t resp_port = 0; - server.ip = ipv4_get_source(pack.ptr + sizeof(eth_hdr_t)); + while (waited < TIMEOUT_MS) { + recvd = socket_recvfrom_udp(sock, + recv_buf, + sizeof(recv_buf), + &resp_ip, + &resp_port); + if (recvd > 0) + break; + sleep(INTERVAL_MS); + waited += INTERVAL_MS; + } - sizedptr payload = udp_parse_packet_payload(pack.ptr); + if (recvd <= 0) { + socket_close_udp(sock); + socket_destroy_udp(sock); + return 0; + } - uint8_t *content = (uint8_t*)payload.ptr; + socket_close_udp(sock); + socket_destroy_udp(sock); - kprintf("PAYLOAD: %s",(uintptr_t)string_ca_max(content, payload.size).data); + memcpy(out_l2->mac, local->mac, 6); + out_l2->ip = resp_ip; + out_l4->ip = resp_ip; + out_l4->port = resp_port; + + return resp_ip; +} + + +void free_request(HTTPRequestMsg *req) +{ + if (req->path.mem_length) + free(req->path.data, req->path.mem_length); + + for (uint32_t i = 0; i < req->extra_header_count; ++i) { + HTTPHeader *h = &req->extra_headers[i]; + if (h->key.mem_length) + free(h->key.data, h->key.mem_length); + if (h->value.mem_length) + free(h->value.data, h->value.mem_length); + } - unbind_port(7777); + if (req->extra_headers) + free(req->extra_headers, req->extra_header_count * sizeof(HTTPHeader)); - return strcmp(content, "world", false) == 0; + if (req->body.ptr && req->body.size) + free((void*)req->body.ptr, req->body.size); } -void test_network(){ - if (!find_server()){ - kprintf("Could not find update server"); +void http_server_hello_entry() +{ + uint16_t pid = get_current_proc_pid(); + http_server_handle_t srv = http_server_create(pid); + if (!srv) { + stop_current_process(); return; } - bind_port(8888); - server.port = 80; - - sizedptr http = request_http_data(GET, &server, 8888); + if (http_server_bind(srv, 80) < 0) { + http_server_destroy(srv); + stop_current_process(); + return; + } - if (http.ptr != 0){ - kprintf("Parsing payload"); - sizedptr payload = http_get_payload(http); - string content = http_get_chunked_payload(payload); - printf("Received payload %s",(uintptr_t)content.data); + if (http_server_listen(srv, 4) < 0) { + http_server_close(srv); + http_server_destroy(srv); + stop_current_process(); + return; } - unbind_port(8888); + KP("[HTTP] listening at %i.%i.%i.%i on port 80", FORMAT_IP(ipv4_get_cfg()->ip)); + + static const char HTML_ROOT[] = + "

Hello, world!

\n" + "

[Redacted]

"; + + static const char HTML_404[] = + "

404 Regrettably, no such page exists in this realm

\n" + "

Im rather inclined to deduce that your page simply does not exist. Given the state of affairs, I dare say it's not altogether surprising, innit?

"; + //comically british 404 error page + const string STR_OK = string_from_const("OK"); + const string STR_HTML = string_from_const("text/html"); + const string STR_CLOSE = string_from_const("close"); + const string STR_NOTFOUND = string_from_const("Not Found"); + + while (1) { + http_connection_handle_t conn = http_server_accept(srv); + if (!conn) + continue; + + HTTPRequestMsg req = http_server_recv_request(srv, conn); + + if (req.path.length) { + char tmp[128] = {0}; + uint32_t n = req.path.length < sizeof(tmp) - 1 + ? req.path.length + : sizeof(tmp) - 1; + memcpy(tmp, req.path.data, n); + KP("[HTTP] GET %s", tmp); + } + + HTTPResponseMsg res = {0}; + + if (req.path.length == 1 && req.path.data[0] == '/') { + res.status_code = HTTP_OK; + res.reason = STR_OK; + res.headers_common.length = sizeof(HTML_ROOT) - 1; + res.headers_common.type = STR_HTML; + res.headers_common.connection = STR_CLOSE; + res.body.ptr = (uintptr_t)HTML_ROOT; + res.body.size = sizeof(HTML_ROOT) - 1; + } + else { + res.status_code = HTTP_NOT_FOUND; + res.reason = STR_NOTFOUND; + res.headers_common.length = sizeof(HTML_404) - 1; + res.headers_common.type = STR_HTML; + res.headers_common.connection = STR_CLOSE; + res.body.ptr = (uintptr_t)HTML_404; + res.body.size = sizeof(HTML_404) - 1; + } + + http_server_send_response(srv, conn, &res); + http_connection_close(conn); + free_request(&req); + } } -uint32_t negotiate_dhcp(){ - kprintf("Sending DHCP request"); - network_connection_ctx *ctx = network_get_context(); - dhcp_request request = (dhcp_request){ - .mac = 0, - .offered_ip = 0, - .server_ip = 0, - }; - memcpy(request.mac, ctx->mac, 6); - send_packet(DHCP, 53, ctx, &request, sizeof(dhcp_request)); - - sizedptr ptr; - - dhcp_packet *payload; - - for (int i = 5; i >= 0; i--){ - while (!read_packet(&ptr));//TODO. Timeout. Opt - kprintf("Received DHCP response"); - payload = dhcp_parse_packet_payload(ptr.ptr); - uint16_t opt_index = dhcp_parse_option(payload, 53); - if (payload->options[opt_index + 2] == 2) - break; - if (i == 0) - return 60000; + + + +static void test_http(uint32_t ip) +{ + KP("[HTTP] GET %i.%i.%i.%i:80\n", + (ip >> 24) & 0xFF, + (ip >> 16) & 0xFF, + (ip >> 8) & 0xFF, + (ip ) & 0xFF); + + uint16_t pid = get_current_proc_pid(); + http_client_handle_t cli = http_client_create(pid); + if (!cli) + return; + + if (http_client_connect(cli, ip, 80) < 0) { + http_client_destroy(cli); + return; } - uint32_t local_ip = __builtin_bswap32(payload->yiaddr); - kprintf("Received local IP %i.%i.%i.%i",(local_ip >> 24) & 0xFF,(local_ip >> 16) & 0xFF,(local_ip >> 8) & 0xFF,(local_ip >> 0) & 0xFF); - request.offered_ip = payload->yiaddr; + HTTPRequestMsg req = {0}; + req.method = HTTP_METHOD_GET; + req.path = string_from_const("/"); + req.headers_common.connection = string_from_const("close"); - uint16_t serv_index = dhcp_parse_option(payload, 54); - memcpy((void*)&request.server_ip, (void*)(payload->options + serv_index + 2), min(4,payload->options[serv_index+1])); + HTTPResponseMsg resp = http_client_send_request(cli, &req); + free(req.path.data, req.path.mem_length); + free(req.headers_common.connection.data, req.headers_common.connection.mem_length); - uint16_t lease_index = dhcp_parse_option(payload, 51); - uint32_t lease_time; - memcpy((void*)&lease_time, (void*)(payload->options + lease_index + 2), min(4,payload->options[lease_index+1])); + if (resp.body.ptr && resp.body.size > 0) { + char *body_str = (char*)malloc(resp.body.size + 1); + if (body_str) { + memcpy(body_str, (void*)resp.body.ptr, resp.body.size); + body_str[resp.body.size] = '\0'; + KP("[HTTP] %i %i bytes of body%s\n", + (uint64_t)resp.status_code, + (uint64_t)resp.body.size, + body_str); + free(body_str, resp.body.size + 1); + } + } + + http_client_close(cli); + http_client_destroy(cli); + if (resp.reason.data && resp.reason.mem_length) + free(resp.reason.data, resp.reason.mem_length); + for (uint32_t i = 0; i < resp.extra_header_count; i++) { + HTTPHeader *h = &resp.extra_headers[i]; + if (h->key.mem_length) + free(h->key.data, h->key.mem_length); + if (h->value.mem_length) + free(h->value.data, h->value.mem_length); + } + if (resp.extra_headers) + free(resp.extra_headers, + resp.extra_header_count * sizeof(HTTPHeader)); +} - lease_time /= 2; +void test_network() +{ + const net_cfg_t *cfg = ipv4_get_cfg(); + net_l2l3_endpoint l2 = {0}; + net_l4_endpoint srv = {0}; - send_packet(DHCP, 53, ctx, &request, sizeof(dhcp_request)); + if (cfg && cfg->mode != NET_MODE_DISABLED && cfg->ip != 0) { + uint32_t bcast = ipv4_broadcast(cfg->ip, cfg->mask); + KP("[NET] probing broadcast %i.%i.%i.%i", + (bcast>>24)&0xFF,(bcast>>16)&0xFF, + (bcast>>8)&0xFF,(bcast&0xFF)); - for (int i = 5; i >= 0; i--){ - while (!read_packet(&ptr)); - kprintf("Received DHCP response");//TODO. Timeout. Opt - payload = dhcp_parse_packet_payload(ptr.ptr); - uint16_t opt_index = dhcp_parse_option(payload, 53); - if (payload->options[opt_index + 2] == 5) - break; - if (i == 0) - return 60000; + if (udp_probe_server(bcast, 8080, &l2, &srv)) { + test_http(srv.ip); + } + http_server_hello_entry(); + return; } - kprintf("DHCP negotiation finished. Lease %i",lease_time); + uint32_t fallback = pick_probe_ip(); + if (!fallback) + fallback = (192<<24)|(168<<16)|(1<<8)|255; + if (udp_probe_server(fallback, 8080, &l2, &srv)) { + test_http(srv.ip); + } else { + KP("[NET] could not find update server\n"); + } +} - //We can parse options for - //DHCP Server identifier - //Subnet mask - //Router (3) - //DNS (8 bytes) (6) - //TODO: Make subsequent DHCP requests (renewals and requests) directed to the server - ctx->ip = local_ip; +void net_test_entry(){ test_network(); - return lease_time; + stop_current_process(); } -void dhcp_daemon(){ - bind_port(68); - while (true){ - uint32_t await = negotiate_dhcp(); - if (await == 0) break; - kprintf("DHCP Negotiated for %i",await); - sleep(await); +void ip_waiter_entry() +{ + for (;;) { + const net_cfg_t *cfg = ipv4_get_cfg(); + if (cfg && cfg->mode != NET_MODE_DISABLED && cfg->ip != 0) { + create_kernel_process("net_test", net_test_entry); + break; + } + sleep(200); } - bind_port(68); stop_current_process(); } -process_t* launch_net_process(){ - return create_kernel_process("dhcp_daemon",dhcp_daemon); +process_t* launch_net_process() +{ + const net_cfg_t *cfg = ipv4_get_cfg(); + + process_t* net = create_kernel_process("net_net", network_net_task_entry); + network_net_set_pid(net ? net->id : 0xFFFF); + + process_t* arp = create_kernel_process("arp_daemon", arp_daemon_entry); + arp_set_pid(arp ? arp->id : 0xFFFF); + + if (cfg && cfg->mode != NET_MODE_DISABLED && cfg->ip != 0) { + create_kernel_process("net_test", net_test_entry); + return NULL; + } + + process_t* dhcp = create_kernel_process("dhcp_daemon", dhcp_daemon_entry); + dhcp_set_pid(dhcp ? dhcp->id : 0xFFFF); + create_kernel_process("ip_waiter", ip_waiter_entry); + return dhcp; } diff --git a/kernel/networking/processes/net_proc.h b/kernel/networking/processes/net_proc.h index 0a6e84ed..81c2921c 100644 --- a/kernel/networking/processes/net_proc.h +++ b/kernel/networking/processes/net_proc.h @@ -1,5 +1,14 @@ #pragma once - #include "process/process.h" -process_t* launch_net_process(); \ No newline at end of file +process_t* launch_net_process(); + +#ifdef __cplusplus +extern "C" { +#endif + +void test_network(); + +#ifdef __cplusplus +} +#endif diff --git a/kernel/process/scheduler.c b/kernel/process/scheduler.c index d7344adf..f8b9f06a 100644 --- a/kernel/process/scheduler.c +++ b/kernel/process/scheduler.c @@ -7,7 +7,7 @@ #include "input/input_dispatch.h" #include "exceptions/exception_handler.h" #include "exceptions/timer.h" - +#include "console/kconsole/kconsole.h" extern void save_context(process_t* proc); extern void save_pc_interrupt(process_t* proc); extern void restore_context(process_t* proc); @@ -63,6 +63,7 @@ void process_restore(){ } void start_scheduler(){ + kconsole_clear(); disable_interrupt(); timer_init(1); switch_proc(YIELD); @@ -204,20 +205,24 @@ void sleep_process(uint64_t msec){ } void wake_processes(){ - uint16_t removed = 0; - uint64_t new_wake_time = 0; - for (uint16_t i = 0; i < sleep_count; i++){ - uint64_t wake_time = sleeping[i].timestamp + sleeping[i].sleep_time; - if (wake_time <= timer_now_msec()){ - process_t *proc = get_proc_by_pid(sleeping[i].pid); - proc->state = READY; - sleeping[i].valid = false; - removed++; - } else if (new_wake_time == 0 || wake_time < new_wake_time){ - new_wake_time = wake_time; + uint64_t now = timer_now_msec(); + uint64_t next = UINT64_MAX; + uint16_t w = 0; + for(uint16_t i=0;istate = READY; + }else{ + if(wake < next) next = wake; + sleeping[w++] = sleeping[i]; } } - sleep_count -= removed; - virtual_timer_reset(timer_now_msec() - new_wake_time); - virtual_timer_enable(); + sleep_count = w; + + if(next != UINT64_MAX){ + virtual_timer_reset(next - now); + virtual_timer_enable(); + } } \ No newline at end of file diff --git a/kernel/process/syscall.c b/kernel/process/syscall.c index 0f72ef79..8882644a 100644 --- a/kernel/process/syscall.c +++ b/kernel/process/syscall.c @@ -13,6 +13,7 @@ #include "std/string.h" #include "exceptions/timer.h" #include "networking/network.h" +#include "networking/port_manager.h" void sync_el0_handler_c(){ save_context_registers(); @@ -131,24 +132,34 @@ void sync_el0_handler_c(){ result = timer_now_msec(); break; - case 51: - result = network_bind_port(x0, get_current_proc_pid()); + case 51: { //bind + uint16_t port = (uint16_t)x0; + port_recv_handler_t handler = (port_recv_handler_t)x1; + protocol_t proto = (protocol_t)x2; + uint16_t pid = get_current_proc_pid(); + result = port_bind_manual(port, pid, proto, handler); break; + } - case 52: - result = network_unbind_port(x0, get_current_proc_pid()); + case 52: { //unbind + uint16_t port = (uint16_t)x0; + protocol_t proto = (protocol_t)x2; + uint16_t pid = get_current_proc_pid(); + result = port_unbind(port, proto, pid); break; + } - case 53: - network_connection_ctx *ctx = (network_connection_ctx*)x2; - void* payload = (void*)x3; - network_send_packet(x0, x1, ctx, payload, x4); + case 53: { //net_tx_frame + uintptr_t frame_ptr = x0; + uint32_t frame_len = (uint32_t)x1; + result = net_tx_frame(frame_ptr, frame_len); break; - - case 54: - sizedptr *ptr = (sizedptr*)x0; - result = network_read_packet_current(ptr); + } + case 54: { //net_rx_frame + sizedptr *user_out = (sizedptr*)x0; + result = net_rx_frame(user_out); break; + } default: handle_exception_with_info("Unknown syscall", iss); @@ -165,7 +176,8 @@ void sync_el0_handler_c(){ stop_current_process(); } } - save_syscall_return(result); + if (result > 0) + save_syscall_return(result); process_restore(); } diff --git a/run_raspi b/run_raspi index 4051057c..aa0b17eb 100755 --- a/run_raspi +++ b/run_raspi @@ -32,4 +32,4 @@ $PRIVILEGE qemu-system-aarch64 \ -serial mon:stdio \ -device usb-kbd \ -d guest_errors \ -$ARGS \ No newline at end of file +$ARGS diff --git a/run_virt b/run_virt index b1016e9b..72f84b52 100755 --- a/run_virt +++ b/run_virt @@ -20,18 +20,29 @@ OS_TYPE="$(uname)" DISPLAY_MODE="default" SELECTED_GPU="virtio-gpu-pci" +DUMP="" + if [ "$OS_TYPE" = "Darwin" ]; then - NETARG="vmnet-bridged,id=net0,ifname=en0" + NETDEV="-netdev vmnet-bridged,id=net0,ifname=en0" PRIVILEGE="sudo" DISPLAY_MODE="sdl" elif [ "$OS_TYPE" = "Linux" ]; then - NETARG="user,id=net0" + NETDEV="-netdev user,id=net0" PRIVILEGE="" else echo "Unknown OS: $OS_TYPE" >&2 exit 1 fi +if [ -d /sys/class/net/tap0 ] && [ -d /sys/class/net/br0 ]; then + #tap bridge + NETDEV="-netdev tap,id=net0,ifname=tap0,script=no,downscript=no,vnet_hdr=off" + PRIVILEGE="" + DUMP="-object filter-dump,id=f0,netdev=net0,file=/tmp/virtio.pcap" +fi + +echo "Using networking mode: $NETDEV" + $PRIVILEGE qemu-system-aarch64 \ -M virt \ -cpu cortex-a72 \ @@ -39,8 +50,9 @@ $PRIVILEGE qemu-system-aarch64 \ -kernel kernel.elf \ -device $SELECTED_GPU \ -display $DISPLAY_MODE \ - -netdev $NETARG \ - -device virtio-net-pci,netdev=net0 \ + $NETDEV \ + -device virtio-net-pci,netdev=net0,mac=52:54:00:12:34:56 \ + $DUMP \ -serial mon:stdio \ -drive file=disk.img,if=none,format=raw,id=hd0 \ -device virtio-blk-pci,drive=hd0 \ @@ -49,4 +61,4 @@ $PRIVILEGE qemu-system-aarch64 \ -device virtio-sound-pci,audiodev=sdl_audio \ -audiodev sdl,id=sdl_audio \ -d guest_errors \ - $ARGS \ No newline at end of file + $ARGS diff --git a/rundebug b/rundebug index 48b8dfb2..965618f0 100755 --- a/rundebug +++ b/rundebug @@ -1,6 +1,6 @@ #!/bin/bash -MODE="virt" +MODE="raspi" ARGS=() for arg in "$@"; do @@ -10,5 +10,9 @@ for arg in "$@"; do esac done -./run_$MODE debug & -./debug ${ARGS[*]} \ No newline at end of file +osascript < -class LinkedList { +class DoubleLinkedList { private: struct Node { T data; @@ -32,16 +32,24 @@ class LinkedList { free(n, sizeof(Node)); } - static void swap(LinkedList& a, LinkedList& b) noexcept { - std::swap(a.head, b.head); - std::swap(a.tail, b.tail); - std::swap(a.length, b.length); + static void swap(DoubleLinkedList& a, DoubleLinkedList& b) noexcept { + Node* tmpHead = a.head; + a.head = b.head; + b.head = tmpHead; + + Node* tmpTail = a.tail; + a.tail = b.tail; + b.tail = tmpTail; + + size_t tmpLen = a.length; + a.length = b.length; + b.length = tmpLen; } public: - LinkedList() : head(nullptr), tail(nullptr), length(0) {} + DoubleLinkedList() : head(nullptr), tail(nullptr), length(0) {} - LinkedList(const LinkedList& other) : head(nullptr), tail(nullptr), length(0) { + DoubleLinkedList(const DoubleLinkedList& other) : head(nullptr), tail(nullptr), length(0) { if (other.head) { Node* it = other.head; do { @@ -51,13 +59,13 @@ class LinkedList { } } - ~LinkedList() { + ~DoubleLinkedList() { while (!empty()) pop_front(); } - LinkedList& operator=(const LinkedList& other) { + DoubleLinkedList& operator=(const DoubleLinkedList& other) { if (this != &other) { - LinkedList tmp(other); + DoubleLinkedList tmp(other); swap(*this, tmp); } return *this; diff --git a/shared/data_struct/linked_list.c b/shared/data_struct/linked_list.c index 0ffdd600..67413393 100644 --- a/shared/data_struct/linked_list.c +++ b/shared/data_struct/linked_list.c @@ -1,6 +1,6 @@ #include "linked_list.h" -clinkedlist_t *clinkedlist_create(void){ +clinkedlist_t *clinkedlist_create(){ uintptr_t mem = malloc(sizeof(clinkedlist_t)); if((void *)mem == NULL) return NULL; clinkedlist_t *list = (clinkedlist_t *)mem; diff --git a/shared/data_struct/linked_list.h b/shared/data_struct/linked_list.h index 4e27a8eb..c0c32815 100644 --- a/shared/data_struct/linked_list.h +++ b/shared/data_struct/linked_list.h @@ -19,7 +19,7 @@ typedef struct clinkedlist { extern uintptr_t malloc(uint64_t size); extern void free(void *ptr, uint64_t size); -clinkedlist_t *clinkedlist_create(void); +clinkedlist_t *clinkedlist_create(); void clinkedlist_destroy(clinkedlist_t *list); clinkedlist_t *clinkedlist_clone(const clinkedlist_t *list); void clinkedlist_push_front(clinkedlist_t *list, void *data); diff --git a/shared/net/application_layer/csocket_http_client.cpp b/shared/net/application_layer/csocket_http_client.cpp new file mode 100644 index 00000000..baac959b --- /dev/null +++ b/shared/net/application_layer/csocket_http_client.cpp @@ -0,0 +1,54 @@ +#pragma once +#include "csocket_http_client.h" +#include "socket_http_client.hpp" +#include "net/transport_layer/socket_tcp.hpp" + +extern "C" { + extern uintptr_t malloc(uint64_t size); + extern void free(void *ptr, uint64_t size); + extern void sleep(uint64_t ms); +} + +extern "C" { + +http_client_handle_t http_client_create(uint16_t pid) { + uintptr_t mem = malloc(sizeof(HTTPClient)); + if (!mem) return NULL; + HTTPClient *cli = reinterpret_cast( (void*)mem ); + return reinterpret_cast(new HTTPClient(pid)); +} + +void http_client_destroy(http_client_handle_t h) { + if (!h) return; + HTTPClient *cli = reinterpret_cast(h); + cli->~HTTPClient(); + free(cli, sizeof(HTTPClient)); +} + +int32_t http_client_connect(http_client_handle_t h, + uint32_t ip, + uint16_t port) +{ + if (!h) return (int32_t)SOCK_ERR_INVAL; + HTTPClient *cli = reinterpret_cast(h); + return cli->connect(ip, port); +} + +HTTPResponseMsg http_client_send_request(http_client_handle_t h, + const HTTPRequestMsg *req) +{ + HTTPResponseMsg empty; + if (!h || !req) { + empty.status_code = (HttpError)SOCK_ERR_INVAL; + return empty; + } + HTTPClient *cli = reinterpret_cast(h); + return cli->send_request(*req); +} + +int32_t http_client_close(http_client_handle_t h) { + if (!h) return (int32_t)SOCK_ERR_INVAL; + HTTPClient *cli = reinterpret_cast(h); + return cli->close(); +} +} diff --git a/shared/net/application_layer/csocket_http_client.h b/shared/net/application_layer/csocket_http_client.h new file mode 100644 index 00000000..f3996578 --- /dev/null +++ b/shared/net/application_layer/csocket_http_client.h @@ -0,0 +1,27 @@ +#pragma once + +#include "http.h" +#include "std/string.h" +#include "std/memfunctions.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void* http_client_handle_t; + +http_client_handle_t http_client_create(uint16_t pid); +void http_client_destroy(http_client_handle_t h); + +int32_t http_client_connect(http_client_handle_t h, + uint32_t ip, + uint16_t port); + +HTTPResponseMsg http_client_send_request(http_client_handle_t h, + const HTTPRequestMsg *req); + +int32_t http_client_close(http_client_handle_t h); + +#ifdef __cplusplus +} +#endif diff --git a/shared/net/application_layer/csocket_http_server.cpp b/shared/net/application_layer/csocket_http_server.cpp new file mode 100644 index 00000000..7398cf43 --- /dev/null +++ b/shared/net/application_layer/csocket_http_server.cpp @@ -0,0 +1,81 @@ +#include "csocket_http_server.h" +#include "socket_http_server.hpp" + +extern "C" { + extern uintptr_t malloc(uint64_t size); + extern void free(void *ptr, uint64_t size); + extern void sleep(uint64_t ms); +} + +extern "C" { + +http_server_handle_t http_server_create(uint16_t pid) { + void* raw = (void*)malloc(sizeof(HTTPServer)); + if (!raw) return nullptr; + HTTPServer* srv = reinterpret_cast(raw); + return reinterpret_cast(new HTTPServer(pid)); +} + +void http_server_destroy(http_server_handle_t h) { + if (!h) return; + HTTPServer* srv = reinterpret_cast(h); + srv->~HTTPServer(); + free(srv, sizeof(HTTPServer)); +} + +int32_t http_server_bind(http_server_handle_t h, uint16_t port) { + if (!h) return (int32_t)SOCK_ERR_INVAL; + HTTPServer* srv = reinterpret_cast(h); + return srv->bind(port); +} + +int32_t http_server_listen(http_server_handle_t h, int backlog) { + if (!h) return (int32_t)SOCK_ERR_INVAL; + HTTPServer* srv = reinterpret_cast(h); + return srv->listen(backlog); +} + +http_connection_handle_t http_server_accept(http_server_handle_t h) { + if (!h) return nullptr; + HTTPServer* srv = reinterpret_cast(h); + TCPSocket* cli = srv->accept(); + return reinterpret_cast(cli); +} + +HTTPRequestMsg http_server_recv_request(http_server_handle_t h, + http_connection_handle_t c) +{ + HTTPRequestMsg empty{}; + if (!h || !c) { + return empty; + } + HTTPServer* srv = reinterpret_cast(h); + TCPSocket* conn = reinterpret_cast(c); + return srv->recv_request(conn); +} + +int32_t http_server_send_response(http_server_handle_t h, + http_connection_handle_t c, + const HTTPResponseMsg* res) +{ + if (!h || !c || !res) return (int32_t)SOCK_ERR_INVAL; + HTTPServer* srv = reinterpret_cast(h); + TCPSocket* conn = reinterpret_cast(c); + return srv->send_response(conn, *res); +} + +int32_t http_connection_close(http_connection_handle_t c) { + if (!c) return (int32_t)SOCK_ERR_INVAL; + TCPSocket* conn = reinterpret_cast(c); + int32_t r = conn->close(); + delete conn; + return r; +} + +int32_t http_server_close(http_server_handle_t h) { + if (!h) return (int32_t)SOCK_ERR_INVAL; + HTTPServer* srv = reinterpret_cast(h); + return srv->close(); +} + +} diff --git a/shared/net/application_layer/csocket_http_server.h b/shared/net/application_layer/csocket_http_server.h new file mode 100644 index 00000000..e46d221e --- /dev/null +++ b/shared/net/application_layer/csocket_http_server.h @@ -0,0 +1,37 @@ +#pragma once +#include "types.h" +#include "http.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void* http_server_handle_t; +typedef void* http_connection_handle_t; + +http_server_handle_t http_server_create(uint16_t pid); + +void http_server_destroy(http_server_handle_t srv); + +int32_t http_server_bind(http_server_handle_t srv, + uint16_t port); + +int32_t http_server_listen(http_server_handle_t srv, + int backlog); + +http_connection_handle_t http_server_accept(http_server_handle_t srv); + +HTTPRequestMsg http_server_recv_request(http_server_handle_t srv, + http_connection_handle_t conn); + +int32_t http_server_send_response(http_server_handle_t srv, + http_connection_handle_t conn, + const HTTPResponseMsg* res); + +int32_t http_connection_close(http_connection_handle_t conn); + +int32_t http_server_close(http_server_handle_t srv); + +#ifdef __cplusplus +} +#endif diff --git a/shared/net/application_layer/dhcp.c b/shared/net/application_layer/dhcp.c new file mode 100644 index 00000000..431292df --- /dev/null +++ b/shared/net/application_layer/dhcp.c @@ -0,0 +1,76 @@ +#include "dhcp.h" +#include "std/memfunctions.h" +#include "net/transport_layer/udp.h" +#include "net/internet_layer/ipv4.h" +#include "types.h" +#include "net/transport_layer/csocket_udp.h" + +static socket_handle_t g_dhcp_socket = NULL; + +extern uintptr_t malloc(uint64_t size); +extern void free(void *ptr, uint64_t size); +extern void sleep(uint64_t ms); +sizedptr dhcp_build_packet(const dhcp_request *req, + uint8_t msg_type, + uint32_t xid) +{ + dhcp_packet p; + memset(&p, 0, sizeof(p)); + size_t idx = 0; + + p.op = 1; p.htype = 1; p.hlen = 6; p.hops = 0; + p.xid = xid; p.secs = 0; + p.flags = __builtin_bswap16(0x8000); + p.ciaddr = 0; p.yiaddr = 0; p.siaddr = 0; p.giaddr = 0; + memcpy(p.chaddr, req->mac, 6); + + p.options[idx++] = 0x63; p.options[idx++] = 0x82; + p.options[idx++] = 0x53; p.options[idx++] = 0x63; + + p.options[idx++] = 53; p.options[idx++] = 1; + p.options[idx++] = msg_type; + + if (msg_type == DHCPREQUEST) { + p.options[idx++] = 50; p.options[idx++] = 4; + memcpy(&p.options[idx], &req->offered_ip, 4); idx += 4; + if (req->server_ip) { + p.options[idx++] = 54; p.options[idx++] = 4; + memcpy(&p.options[idx], &req->server_ip, 4); idx += 4; + } + } + + p.options[idx++] = 255; + + size_t dhcp_len = sizeof(dhcp_packet) - (sizeof(p.options) - idx); + + uintptr_t buf = malloc(dhcp_len); + memcpy((void*)buf, &p, dhcp_len); + + return (sizedptr){ .ptr = buf, .size = (uint32_t)dhcp_len }; +} + +dhcp_packet* dhcp_parse_frame_payload(uintptr_t frame_ptr) { + return (dhcp_packet*)frame_ptr; +} + +uint16_t dhcp_parse_option(const dhcp_packet *p, uint16_t wanted) { + const uint8_t *opt = p->options; + size_t i= 4; + while (i < sizeof(p->options)) { + uint8_t code = opt[i++]; + if (code == 0) continue; + if (code == 255) break; + if (i >= sizeof(p->options)) break; + uint8_t len = opt[i++]; + if (code == wanted) { + return (uint16_t)(i - 2); + } + i += len; + } + return UINT16_MAX; +} + +uint8_t dhcp_option_len(const dhcp_packet *p, uint16_t idx) { + if (idx == 0 || idx + 1 >= sizeof(p->options)) return 0; + return p->options[idx+1]; +} diff --git a/shared/net/application_layer/dhcp.h b/shared/net/application_layer/dhcp.h new file mode 100644 index 00000000..90c29697 --- /dev/null +++ b/shared/net/application_layer/dhcp.h @@ -0,0 +1,62 @@ +#pragma once + +#include "types.h" +#include "net/network_types.h" +#include "net/link_layer/eth.h" +#include "net/internet_layer/ipv4.h" +#include "net/transport_layer/udp.h" + +#ifdef __cplusplus +extern "C" { +#endif + +enum { + DHCPDISCOVER = 1, + DHCPOFFER = 2, + DHCPREQUEST = 3, + DHCPDECLINE = 4, + DHCPACK = 5, + DHCPNAK = 6, + DHCPRELEASE = 7, + DHCPINFORM = 8 +}; + +#define DHCP_FRAME_MAX ( sizeof(eth_hdr_t) + sizeof(ipv4_hdr_t) + sizeof(udp_hdr_t) + sizeof(dhcp_packet) ) + +typedef struct __attribute__((packed)) { + uint8_t op; + uint8_t htype; + uint8_t hlen; + uint8_t hops; + uint32_t xid; + uint16_t secs; + uint16_t flags; + uint32_t ciaddr; + uint32_t yiaddr; + uint32_t siaddr; + uint32_t giaddr; + uint8_t chaddr[16]; + uint8_t sname[64]; + uint8_t file[128]; + uint8_t options[312]; +} dhcp_packet; + +typedef struct { + uint8_t mac[6]; + uint32_t server_ip; + uint32_t offered_ip; +} dhcp_request; + +sizedptr dhcp_build_packet(const dhcp_request *req, + uint8_t msg_type, + uint32_t xid); + +dhcp_packet* dhcp_parse_frame_payload(uintptr_t frame_ptr); + +uint16_t dhcp_parse_option(const dhcp_packet *p, uint16_t wanted); + +uint8_t dhcp_option_len(const dhcp_packet *p, uint16_t idx); + +#ifdef __cplusplus +} +#endif diff --git a/shared/net/application_layer/dhcp_daemon.c b/shared/net/application_layer/dhcp_daemon.c new file mode 100644 index 00000000..9fcc99c4 --- /dev/null +++ b/shared/net/application_layer/dhcp_daemon.c @@ -0,0 +1,368 @@ +#include "dhcp_daemon.h" + +#include "console/kio.h" +#include "std/memfunctions.h" +#include "process/scheduler.h" +#include "math/math.h" +#include "math/rng.h" + +#include "networking/network.h" +#include "net/application_layer/dhcp.h" +#include "net/internet_layer/ipv4.h" +#include "net/network_types.h" +#include "net/link_layer/arp.h" + +#include "net/transport_layer/csocket_udp.h" + +#include "../net.h" + +extern void sleep(uint64_t ms); +extern uintptr_t malloc(uint64_t size); +extern void free(void *ptr, uint64_t size); + +#ifndef SOCK_ROLE_SERVER +#define SOCK_ROLE_SERVER 1 +#endif + +#define DHCPDISCOVER 1 +#define DHCPOFFER 2 +#define DHCPREQUEST 3 +#define DHCPDECLINE 4 +#define DHCPACK 5 +#define DHCPNAK 6 +#define DHCPRELEASE 7 +#define DHCPINFORM 8 + +typedef enum { + DHCP_S_INIT = 0, + DHCP_S_SELECTING, + DHCP_S_REQUESTING, + DHCP_S_BOUND, + DHCP_S_RENEWING, + DHCP_S_REBINDING +} dhcp_state_t; + +#define KP(fmt, ...) \ + do { kprintf(fmt, ##__VA_ARGS__); } while(0) +static dhcp_state_t g_state = DHCP_S_INIT; +static net_l2l3_endpoint g_local_ep = {0}; +static volatile bool g_force_renew = false; +static uint32_t g_t1_left_ms = 0; +static uint32_t g_t2_left_ms = 0; +static uint16_t g_pid_dhcpd = 0xFFFF; + +static socket_handle_t g_sock = 0; + +uint16_t get_dhcp_pid() { return g_pid_dhcpd; } +bool dhcp_is_running() { return g_pid_dhcpd != 0xFFFF; } +void dhcp_set_pid(uint16_t p){ g_pid_dhcpd = p; } +void dhcp_force_renew() { g_force_renew = true; } + +static inline uint32_t rd_be32(const uint8_t* p){ + uint32_t v; memcpy(&v, p, 4); return __builtin_bswap32(v); +} + +static void log_state_change(dhcp_state_t old, dhcp_state_t now){ + KP("[DHCP] state %i -> %i", old, now); +} + +static void dhcp_apply_offer(dhcp_packet *p, dhcp_request *req, uint32_t xid); + +static void dhcp_tx_packet(const dhcp_request *req, + uint8_t msg_type, + uint32_t xid, + uint32_t dst_ip) +{ + sizedptr pkt = dhcp_build_packet(req, msg_type, xid); + socket_sendto_udp(g_sock, dst_ip, 67, (const void*)pkt.ptr, pkt.size); + free((void*)pkt.ptr, pkt.size); +} + +static void dhcp_send_discover(uint32_t xid){ + KP("[DHCP] discover xid=%i", xid); + dhcp_request req = {0}; + memcpy(req.mac, g_local_ep.mac, 6); + dhcp_tx_packet(&req, DHCPDISCOVER, xid, 0xFFFFFFFFu); +} + +static void dhcp_send_request(const dhcp_request *req, + uint32_t xid, + bool broadcast) +{ + uint32_t dst = broadcast ? 0xFFFFFFFFu : __builtin_bswap32(req->server_ip); + //KP("[DHCP] request xid=%i dst=%x\n", (uint64_t)xid, (uint64_t)dst); + dhcp_tx_packet(req, DHCPREQUEST, xid, dst); +} + +static void dhcp_send_renew(uint32_t xid) { + const net_cfg_t *cfg = ipv4_get_cfg(); + dhcp_request req = {0}; + memcpy(req.mac, g_local_ep.mac, 6); + req.offered_ip = __builtin_bswap32(cfg->ip); + req.server_ip = cfg->rt ? cfg->rt->server_ip : 0; + uint32_t dst = req.server_ip ? __builtin_bswap32(req.server_ip) : 0xFFFFFFFFu; + KP("[DHCP] renew xid=%i dst=%x", xid, dst); + dhcp_tx_packet(&req, DHCPREQUEST, xid, dst); +} + +static void dhcp_send_rebind(uint32_t xid) { + const net_cfg_t *cfg = ipv4_get_cfg(); + dhcp_request req = {0}; + memcpy(req.mac, g_local_ep.mac, 6); + req.offered_ip = __builtin_bswap32(cfg->ip); + req.server_ip = 0; + KP("[DHCP] rebind xid=%i", xid); + dhcp_tx_packet(&req, DHCPREQUEST, xid, 0xFFFFFFFFu); +} + +static bool dhcp_wait_for_type(uint8_t wanted, + dhcp_packet **outp, + sizedptr *outsp, + uint32_t timeout_ms) +{ + uint32_t waited = 0; + while(waited < timeout_ms){ + uint8_t buf[1024]; + uint32_t sip; uint16_t sport; + int64_t r = socket_recvfrom_udp(g_sock, buf, sizeof(buf), &sip, &sport); + if(r > 0){ + dhcp_packet *p = (dhcp_packet*)buf; + uint16_t idx= dhcp_parse_option(p, 53); + if (idx != UINT16_MAX && p->options[idx + 2] == wanted){ + uintptr_t copy = malloc((uint32_t)r); + memcpy((void*)copy, buf, (size_t)r); + if (outp) *outp= (dhcp_packet*)copy; + if (outsp) *outsp = (sizedptr){ copy, (uint32_t)r }; + return true; + } + } else { + sleep(50); + waited += 50; + } + } + KP("[DHCP] wait timeout type=%i", wanted); + return false; +} + +static void dhcp_fsm_once() +{ + //TODO: use a syscall for the rng + rng_t rng; + rng_init_random(&rng); + uint32_t xid_seed = rng_next32(&rng); + dhcp_state_t old = g_state; + + switch (g_state) { + + case DHCP_S_INIT: { + const net_l2l3_endpoint *le = network_get_local_endpoint(); + memcpy(g_local_ep.mac, le->mac, 6); + g_local_ep.ip = 0; + xid_seed += 0x1111; + dhcp_send_discover(xid_seed); + g_state = DHCP_S_SELECTING; + } break; + + case DHCP_S_SELECTING: { + dhcp_packet *offer = NULL; sizedptr sp = {0}; + if (!dhcp_wait_for_type(DHCPOFFER, &offer, &sp, 5000)) { + g_state = DHCP_S_INIT; + break; + } + dhcp_request req = {0}; + memcpy(req.mac, g_local_ep.mac, 6); + dhcp_apply_offer(offer, &req, xid_seed); + free((void*)sp.ptr, sp.size); + + xid_seed += 0x0101; + dhcp_send_request(&req, xid_seed, true); + g_state = DHCP_S_REQUESTING; + } break; + + case DHCP_S_REQUESTING: { + dhcp_packet *ack = NULL; sizedptr sp = {0}; + if (!dhcp_wait_for_type(DHCPACK, &ack, &sp, 5000)) { + g_state = DHCP_S_INIT; + } else { + dhcp_request dummy = {0}; + memcpy(dummy.mac, g_local_ep.mac, 6); + dhcp_apply_offer(ack, &dummy, xid_seed); + free((void*)sp.ptr, sp.size); + g_state = DHCP_S_BOUND; + } + } break; + + case DHCP_S_BOUND: { + if (g_force_renew) { + g_force_renew = false; + xid_seed += 0x2222; + dhcp_send_renew(xid_seed); + g_state = DHCP_S_RENEWING; + + } else if (g_t2_left_ms == 0) { + xid_seed += 0x3333; + dhcp_send_rebind(xid_seed); + g_state = DHCP_S_REBINDING; + + } else if (g_t1_left_ms == 0) { + xid_seed += 0x2222; + dhcp_send_renew(xid_seed); + g_state = DHCP_S_RENEWING; + } + } break; + + case DHCP_S_RENEWING: { + dhcp_packet *p = NULL; sizedptr sp = {0}; + if (dhcp_wait_for_type(DHCPACK, &p, &sp, 2000)) { + dhcp_request dummy = {0}; + memcpy(dummy.mac, g_local_ep.mac, 6); + dhcp_apply_offer(p, &dummy, xid_seed); + free((void*)sp.ptr, sp.size); + g_state = DHCP_S_BOUND; + } else { + xid_seed += 0x3333; + dhcp_send_rebind(xid_seed); + g_state = DHCP_S_REBINDING; + } + } break; + + case DHCP_S_REBINDING: { + dhcp_packet *p = NULL; sizedptr sp = {0}; + if (dhcp_wait_for_type(DHCPACK, &p, &sp, 2000)) { + dhcp_request dummy = {0}; + memcpy(dummy.mac, g_local_ep.mac, 6); + dhcp_apply_offer(p, &dummy, xid_seed); + free((void*)sp.ptr, sp.size); + g_state = DHCP_S_BOUND; + } else { + net_cfg_t g_net_cfg; + g_net_cfg.ip = 0; + g_net_cfg.mode = NET_MODE_DISABLED; + ipv4_set_cfg(&g_net_cfg); + g_state = DHCP_S_INIT; + } + } break; + } + + if (old != g_state) log_state_change(old, g_state); +} + +void dhcp_daemon_entry(){ + KP("[DHCP] daemon start pid=%i", get_current_proc_pid()); + g_pid_dhcpd = (uint16_t)get_current_proc_pid(); + g_sock = udp_socket_create(SOCK_ROLE_SERVER, g_pid_dhcpd); + if(socket_bind_udp(g_sock, 68) != 0){ + KP("[DHCP] bind failed\n"); + return; + } + + for(;;){ + dhcp_fsm_once(); + sleep(100); + + if(g_state == DHCP_S_BOUND){ + if(g_t1_left_ms > 100) g_t1_left_ms -= 100; else g_t1_left_ms = 0; + if(g_t2_left_ms > 100) g_t2_left_ms -= 100; else g_t2_left_ms = 0; + } + } +} + +static void dhcp_apply_offer(dhcp_packet *p, dhcp_request *req, uint32_t xid) { + const net_cfg_t *current = ipv4_get_cfg(); + net_cfg_t cfg_local = *current; + static net_runtime_opts_t rt_static; + memset(&rt_static, 0, sizeof(rt_static)); + cfg_local.rt = &rt_static; + cfg_local.rt->xid = (uint16_t)xid; + cfg_local.mode = NET_MODE_DHCP; + + uint32_t yi_net = p->yiaddr; + cfg_local.ip = __builtin_bswap32(yi_net); + req->offered_ip = yi_net; + + uint16_t idx; + uint8_t len; + + idx = dhcp_parse_option(p, 1); + if (idx != UINT16_MAX && (len = p->options[idx+1]) >= 4) { + uint32_t mask_net; + memcpy(&mask_net, &p->options[idx+2], 4); + cfg_local.mask = __builtin_bswap32(mask_net); + } + + idx = dhcp_parse_option(p, 3); + if (idx != UINT16_MAX && (len = p->options[idx+1]) >= 4) { + uint32_t gw_net; + memcpy(&gw_net, &p->options[idx+2], 4); + cfg_local.gw = __builtin_bswap32(gw_net); + } + + idx = dhcp_parse_option(p, 6); + if (idx != UINT16_MAX) { + len = p->options[idx+1]; + for (int i = 0; i < 2 && (i*4 + 4) <= len; ++i) { + uint32_t dns_net; + memcpy(&dns_net, &p->options[idx+2 + i*4], 4); + cfg_local.rt->dns[i] = __builtin_bswap32(dns_net); + } + } + + idx = dhcp_parse_option(p, 42); + if (idx != UINT16_MAX) { + len = p->options[idx+1]; + for (int i = 0; i < 2 && (i*4 + 4) <= len; ++i) { + uint32_t ntp_net; + memcpy(&ntp_net, &p->options[idx+2 + i*4], 4); + cfg_local.rt->ntp[i] = __builtin_bswap32(ntp_net); + } + } + + idx = dhcp_parse_option(p, 26); + if (idx != UINT16_MAX && p->options[idx+1] == 2) { + uint16_t mtu_net; + memcpy(&mtu_net, &p->options[idx+2], 2); + cfg_local.rt->mtu = __builtin_bswap16(mtu_net); + } + + idx = dhcp_parse_option(p, 51); + if (idx != UINT16_MAX && p->options[idx+1] >= 4) { + uint32_t lease_net; + memcpy(&lease_net, &p->options[idx+2], 4); + cfg_local.rt->lease = __builtin_bswap32(lease_net); + } + + idx = dhcp_parse_option(p, 58); + if (idx != UINT16_MAX && p->options[idx+1] >= 4) { + uint32_t t1_net; + memcpy(&t1_net, &p->options[idx+2], 4); + cfg_local.rt->t1 = __builtin_bswap32(t1_net); + } else { + cfg_local.rt->t1 = cfg_local.rt->lease / 2; + } + idx = dhcp_parse_option(p, 59); + if (idx != UINT16_MAX && p->options[idx+1] >= 4) { + uint32_t t2_net; + memcpy(&t2_net, &p->options[idx+2], 4); + cfg_local.rt->t2 = __builtin_bswap32(t2_net); + } else { + cfg_local.rt->t2 = cfg_local.rt->t1 * 2; + } + + idx = dhcp_parse_option(p, 54); + if (idx != UINT16_MAX && p->options[idx+1] >= 4) { + uint32_t srv_net; + memcpy(&srv_net, &p->options[idx+2], 4); + cfg_local.rt->server_ip = __builtin_bswap32(srv_net); + req->server_ip = srv_net; + } + uint32_t bcast = ipv4_broadcast(cfg_local.ip, cfg_local.mask); + static const uint8_t bmac[6] = {0xFF,0xFF,0xFF,0xFF,0xFF,0xFF}; + arp_table_put(bcast, bmac, 0, true); + + ipv4_set_cfg(&cfg_local); + + kprintf("Local IP: %i.%i.%i.%i",FORMAT_IP(cfg_local.ip)); + + g_t1_left_ms = cfg_local.rt->t1 * 1000; + g_t2_left_ms = cfg_local.rt->t2 * 1000; +} diff --git a/shared/net/application_layer/dhcp_daemon.h b/shared/net/application_layer/dhcp_daemon.h new file mode 100644 index 00000000..fcf501dc --- /dev/null +++ b/shared/net/application_layer/dhcp_daemon.h @@ -0,0 +1,19 @@ +#pragma once +#include "types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +void dhcp_daemon_entry(); +uint16_t get_dhcp_pid(); +bool dhcp_is_running(); +void dhcp_set_pid(uint16_t pid); + +void dhcp_notify_link_up(); +void dhcp_notify_link_down(); +void dhcp_force_renew(); + +#ifdef __cplusplus +} +#endif diff --git a/shared/net/application_layer/http.c b/shared/net/application_layer/http.c new file mode 100644 index 00000000..602d81b9 --- /dev/null +++ b/shared/net/application_layer/http.c @@ -0,0 +1,258 @@ +#include "http.h" +#include "std/string.h" +#include "std/memfunctions.h" +extern uintptr_t malloc(uint64_t size); +extern void free(void *ptr, uint64_t size); +extern void sleep(uint64_t ms); + +static inline bool is_space(char c) { + return c == ' ' || c == '\t'; +} +static inline bool starts_with(const char *a, const char *b, uint32_t len) { + for (uint32_t i = 0; i < len; i++) + if (a[i] != b[i]) return false; + return true; +} +static inline uint32_t parse_u32(const char *s, uint32_t len) { + uint32_t r = 0; + for (uint32_t i = 0; i < len; i++) { + char c = s[i]; + if (c >= '0' && c <= '9') { + r = r * 10 + (uint32_t)(c - '0'); + } else { + break; + } + } + return r; +} + +string http_header_builder(const HTTPHeadersCommon *C, + const HTTPHeader *H, uint32_t N) +{ + string out = string_repeat('\0', 0); + + if (C->type.length) { + string_append_bytes(&out, "Content-Type: ", 14); + string_append_bytes(&out, + C->type.data, + C->type.length); + string_append_bytes(&out, "\r\n", 2); + } + + if (C->length) { + string tmp = string_format("Content-Length: %i\r\n", + (int)C->length); + string_append_bytes(&out, tmp.data, tmp.length); + free(tmp.data, tmp.mem_length); + } + + if (C->date.length) { + string_append_bytes(&out, "Date: ", 6); + string_append_bytes(&out, + C->date.data, + C->date.length); + string_append_bytes(&out, "\r\n", 2); + } + + if (C->host.length) { + string_append_bytes(&out, "Host: ", 6); + string_append_bytes(&out, + C->host.data, + C->host.length); + string_append_bytes(&out, "\r\n", 2); + } else { + string_append_bytes(&out, "Host: RedactedOS_0.1\r\n", 22); + } + + if (C->connection.length) { + string_append_bytes(&out, "Connection: ", 12); + string_append_bytes(&out, + C->connection.data, + C->connection.length); + string_append_bytes(&out, "\r\n", 2); + } + + if (C->keep_alive.length) { + string_append_bytes(&out, "Keep-Alive: ", 12); + string_append_bytes(&out, + C->keep_alive.data, + C->keep_alive.length); + string_append_bytes(&out, "\r\n", 2); + } + + for (uint32_t i = 0; i < N; i++) { + const HTTPHeader *hdr = &H[i]; + string_append_bytes(&out, + hdr->key.data, + hdr->key.length); + string_append_bytes(&out, ": ", 2); + string_append_bytes(&out, + hdr->value.data, + hdr->value.length); + string_append_bytes(&out, "\r\n", 2); + } + + string_append_bytes(&out, "\r\n", 2); + return out; +} + + +void http_header_parser(const char *buf, uint32_t len, + HTTPHeadersCommon *C, + HTTPHeader **out_extra, + uint32_t *out_extra_count) +{ + *C = (HTTPHeadersCommon){0}; + + uint32_t max_lines = 0; + for (uint32_t i = 0; i + 1 < len; i++) { + if (buf[i]=='\r' && buf[i+1]=='\n') + max_lines++; + } + + HTTPHeader *extras = (HTTPHeader*)(uintptr_t)malloc(sizeof(*extras) * max_lines); + if (!extras) { + *out_extra = NULL; + *out_extra_count = 0; + return; + } + uint32_t extra_i = 0; + uint32_t pos = 0; + + char key_tmp[64]; + + while (pos + 1 < len) { + uint32_t eol = pos; + while (eol + 1 < len && !(buf[eol]=='\r' && buf[eol+1]=='\n')) + eol++; + if (eol == pos) { + pos += 2; + break; + } + + uint32_t sep = pos; + while (sep < eol && buf[sep] != ':') sep++; + uint32_t key_len = sep - pos; + uint32_t val_start = sep + 1; + while (val_start < eol && is_space((unsigned char)buf[val_start])) + val_start++; + uint32_t val_len = eol - val_start; + + uint32_t copy_len = (key_len < sizeof(key_tmp)-1) ? key_len : (sizeof(key_tmp)-1); + for (uint32_t i = 0; i < copy_len; i++) { + key_tmp[i] = buf[pos + i]; + } + key_tmp[copy_len] = '\0'; + + if (copy_len == 14 && strcmp(key_tmp, "content-length", true) == 0) { + C->length = parse_u32(buf + val_start, val_len); + } + else if (copy_len == 12 && strcmp(key_tmp, "content-type", true) == 0) { + C->type = string_ca_max((char*)(buf + val_start), val_len); + } + else if (copy_len == 4 && strcmp(key_tmp, "date", true) == 0) { + C->date = string_ca_max((char*)(buf + val_start), val_len); + } + else if (copy_len == 10 && strcmp(key_tmp, "connection", true) == 0) { + C->connection = string_ca_max((char*)(buf + val_start), val_len); + } + else if (copy_len == 10 && strcmp(key_tmp, "keep-alive", true) == 0) { + C->keep_alive = string_ca_max((char*)(buf + val_start), val_len); + } + else { + string key = string_ca_max((char*)(buf + pos), key_len); + string value = string_ca_max((char*)(buf + val_start), val_len); + extras[extra_i++] = (HTTPHeader){ key, value }; + } + + pos = eol + 2; + } + + *out_extra = extras; + *out_extra_count = extra_i; +} + +string http_request_builder(const HTTPRequestMsg *R) +{ + static const char *Mnames[] = { "GET", "POST", "PUT", "DELETE" }; + string out = string_format("%s ", Mnames[R->method]); + + string_append_bytes(&out, R->path.data, R->path.length); + + string_append_bytes(&out, " HTTP/1.1\r\n", 11); + + string hdrs = http_header_builder( + &R->headers_common, + R->extra_headers, R->extra_header_count); + string_append_bytes(&out, hdrs.data, hdrs.length); + free(hdrs.data, hdrs.mem_length); + + if (R->body.ptr && R->body.size) { + string body = string_ca_max((char*)R->body.ptr, R->body.size); + string_append_bytes(&out, body.data, body.length); + free(body.data, body.mem_length); + } + + return out; +} + +string http_response_builder(const HTTPResponseMsg *R) { + string out = string_format("HTTP/1.1 %i ", (int)R->status_code); + string_append_bytes(&out, + R->reason.data, + R->reason.length); + string_append_bytes(&out, "\r\n", 2); + + string hdrs = http_header_builder( + &R->headers_common, + R->extra_headers, + R->extra_header_count + ); + string_append_bytes(&out, hdrs.data, hdrs.length); + free(hdrs.data, hdrs.mem_length); + + if (R->body.ptr && R->body.size) { + string_append_bytes(&out, + (char*)R->body.ptr, + (uint32_t)R->body.size); + } + return out; +} + + +int find_crlfcrlf(const char *data, uint32_t len) { + for (uint32_t i = 0; i + 3 < len; i++) { + if (data[i] == '\r' && + data[i+1] == '\n' && + data[i+2] == '\r' && + data[i+3] == '\n') + { + return (int)i; + } + } + return -1; +} + +sizedptr http_get_payload(sizedptr header) { + if (!header.ptr || header.size < 4) { + return (sizedptr){0}; + } + int start = find_crlfcrlf((char*)header.ptr, header.size); + if (start < 0) { + return (sizedptr){0}; + } + return (sizedptr){ + header.ptr + (uint32_t)(start + 4), + header.size - (uint32_t)(start + 4) + }; +} + +string http_get_chunked_payload(sizedptr chunk) { + if (chunk.ptr && chunk.size > 0) { + int sizetrm = strindex((char*)chunk.ptr, "\r\n"); + uint64_t chunk_size = parse_hex_u64((char*)chunk.ptr, sizetrm); + return string_ca_max((char*)(chunk.ptr + sizetrm + 2), + (uint32_t)chunk_size); + } + return (string){0}; +} \ No newline at end of file diff --git a/shared/net/application_layer/http.h b/shared/net/application_layer/http.h new file mode 100644 index 00000000..d5b4a4eb --- /dev/null +++ b/shared/net/application_layer/http.h @@ -0,0 +1,82 @@ +#pragma once + +#include "std/string.h" +#include "std/memfunctions.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef enum { + HTTP_METHOD_GET, + HTTP_METHOD_POST, + HTTP_METHOD_PUT, + HTTP_METHOD_DELETE +} HTTPMethod; + +typedef enum { + HTTP_OK = 200, + HTTP_BAD_REQUEST = 400, + HTTP_UNAUTHORIZED = 401, + HTTP_FORBIDDEN = 403, + HTTP_NOT_FOUND = 404, + HTTP_INTERNAL_SERVER_ERROR = 500, + HTTP_NOT_IMPLEMENTED = 501, + HTTP_SERVICE_UNAVAILABLE = 503, + HTTP_DEBUG = 800, +} HttpError; + +typedef struct { + string key; + string value; +} HTTPHeader; + +typedef struct { + uint32_t length; + string type; + string date; + string connection; + string keep_alive; + string host; +} HTTPHeadersCommon; + +typedef struct { + HTTPMethod method; + string path; + HTTPHeadersCommon headers_common; + HTTPHeader *extra_headers; + uint32_t extra_header_count; + sizedptr body; +} HTTPRequestMsg; + +typedef struct { + HttpError status_code; + string reason; + HTTPHeadersCommon headers_common; + HTTPHeader *extra_headers; + uint32_t extra_header_count; + sizedptr body; +} HTTPResponseMsg; + +string http_header_builder(const HTTPHeadersCommon *common, + const HTTPHeader *extra, + uint32_t extra_count); + +void http_header_parser(const char *buf, uint32_t len, + HTTPHeadersCommon *out_common, + HTTPHeader **out_extra, + uint32_t *out_extra_count); + +string http_request_builder(const HTTPRequestMsg *req); + +string http_response_builder(const HTTPResponseMsg *res); + +int find_crlfcrlf(const char *data, uint32_t len); + +sizedptr http_get_payload(sizedptr header); + +string http_get_chunked_payload(sizedptr chunk); + +#ifdef __cplusplus +} +#endif diff --git a/shared/net/application_layer/socket_http_client.hpp b/shared/net/application_layer/socket_http_client.hpp new file mode 100644 index 00000000..728ee4c0 --- /dev/null +++ b/shared/net/application_layer/socket_http_client.hpp @@ -0,0 +1,141 @@ +#pragma once +#include "console/kio.h" +#include "net/transport_layer/socket_tcp.hpp" +#include "http.h" +#include "std/string.h" +#include "std/memfunctions.h" +#define KP(fmt, ...) \ + do { kprintf(fmt, ##__VA_ARGS__); } while (0) + +class HTTPClient { +private: + TCPSocket sock; + +public: + explicit HTTPClient(uint16_t pid); + ~HTTPClient(); + int32_t connect(uint32_t ip, uint16_t port); + HTTPResponseMsg send_request(const HTTPRequestMsg &req); + int32_t close(); +}; + +HTTPClient::HTTPClient(uint16_t pid) + : sock(SOCK_ROLE_CLIENT, pid) +{} + +HTTPClient::~HTTPClient() { + sock.close(); +} + +int32_t HTTPClient::connect(uint32_t ip, uint16_t port) { + return sock.connect(ip, port); +} + +HTTPResponseMsg HTTPClient::send_request(const HTTPRequestMsg &req) { + string out = http_request_builder(&req); + int64_t sent = sock.send(out.data, out.length); + free(out.data, out.mem_length); + + HTTPResponseMsg resp{}; + if (sent < 0) { + resp.status_code = (HttpError)sent; + return resp; + } + + string buf = string_repeat('\0', 0); + char tmp[512]; + int attempts = 0; + int hdr_end = -1; + while (true) { + int64_t r = sock.recv(tmp, sizeof(tmp)); + if (r < 0) { + free(buf.data, buf.mem_length); + resp.status_code = (HttpError)SOCK_ERR_SYS; + return resp; + } + if (r > 0) { + string_append_bytes(&buf, tmp, (uint32_t)r); + } + hdr_end = find_crlfcrlf(buf.data, buf.length); + if (hdr_end >= 0) break; + if (++attempts > 50) { + free(buf.data, buf.mem_length); + resp.status_code = (HttpError)SOCK_ERR_PROTO; + return resp; + } + sleep(10); + } + + { + uint32_t i = 0; + while (i < (uint32_t)hdr_end && buf.data[i] != ' ') i++; + uint32_t code = 0, j = i+1; + while (j < (uint32_t)hdr_end && buf.data[j] >= '0' && buf.data[j] <= '9') { + code = code*10 + (buf.data[j]-'0'); ++j; + } + resp.status_code = (HttpError)code; + while (j < (uint32_t)hdr_end && buf.data[j]==' ') ++j; + if (j < (uint32_t)hdr_end) { + uint32_t rlen = hdr_end - j; + resp.reason = string_repeat('\0', 0); + string_append_bytes(&resp.reason, buf.data+j, rlen); + } + } + + HTTPHeader *extras = nullptr; + uint32_t extra_count = 0; + int status_line_end = strindex((char*)buf.data, "\r\n"); + http_header_parser( + (char*)buf.data + status_line_end + 2, + buf.length - (uint32_t)(status_line_end + 2), + &resp.headers_common, + &extras, + &extra_count); + resp.extra_headers = extras; + resp.extra_header_count = extra_count; + + uint32_t body_start = hdr_end + 4; + uint32_t have = (buf.length > body_start) + ? buf.length - body_start + : 0; + + uint32_t need = resp.headers_common.length; + if (need > 0) { + while (have < need) { + int64_t r = sock.recv(tmp, sizeof(tmp)); + if (r <= 0) break; + string_append_bytes(&buf, tmp, (uint32_t)r); + have += (uint32_t)r; + } + } else { + int idle = 0; + while (idle < 5) { + int64_t r = sock.recv(tmp, sizeof(tmp)); + if (r > 0) { + string_append_bytes(&buf, tmp, (uint32_t)r); + have += (uint32_t)r; + idle = 0; + } else { + ++idle; + sleep(20); + } + } + } + if (have > 0) { + char *body_copy = (char*)malloc(have + 1); + if (body_copy) { + memcpy(body_copy, + buf.data + body_start, + have); + body_copy[have] = '\0'; + resp.body.ptr = (uintptr_t)body_copy; + resp.body.size = have; + } + } + free(buf.data, buf.mem_length); + return resp; +} + +int32_t HTTPClient::close() { + return sock.close(); +} diff --git a/shared/net/application_layer/socket_http_server.hpp b/shared/net/application_layer/socket_http_server.hpp new file mode 100644 index 00000000..80d495e7 --- /dev/null +++ b/shared/net/application_layer/socket_http_server.hpp @@ -0,0 +1,125 @@ +#pragma once + +#include "console/kio.h" +#include "net/transport_layer/socket_tcp.hpp" +#include "http.h" +#include "std/string.h" +#include "std/memfunctions.h" + +#define KP(fmt, ...) \ + do { kprintf(fmt, ##__VA_ARGS__); } while (0) + +class HTTPServer { +private: + TCPSocket sock; + +public: + explicit HTTPServer(uint16_t pid) : sock(SOCK_ROLE_SERVER, pid) {} + + ~HTTPServer() { close(); } + + int32_t bind(uint16_t port) { return sock.bind(port); } + int32_t listen(int backlog = 4) { return sock.listen(backlog); } + TCPSocket* accept() { return sock.accept(); } + + HTTPRequestMsg recv_request(TCPSocket* client) { + HTTPRequestMsg req{}; + if (!client) return req; + + string buf = string_repeat('\0', 0); + char tmp[512]; + int attempts = 0; + int hdr_end = -1; + + while (true) { + int64_t r = client->recv(tmp, sizeof(tmp)); + if (r < 0) return req; + if (r > 0) string_append_bytes(&buf, tmp, (uint32_t)r); + hdr_end = find_crlfcrlf(buf.data, buf.length); + if (hdr_end >= 0) break; + if (++attempts > 100) return req; + sleep(10); + } + + uint32_t i = 0; + while (i < (uint32_t)hdr_end && buf.data[i] != ' ') ++i; + string method_tok = string_repeat('\0', 0); + string_append_bytes(&method_tok, buf.data, i); + + if (method_tok.length == 3 && memcmp(method_tok.data, "GET", 3) == 0) + req.method = HTTP_METHOD_GET; + else if (method_tok.length == 4 && memcmp(method_tok.data, "POST", 4) == 0) + req.method = HTTP_METHOD_POST; + else if (method_tok.length == 3 && memcmp(method_tok.data, "PUT", 3) == 0) + req.method = HTTP_METHOD_PUT; + else if (method_tok.length == 6 && memcmp(method_tok.data, "DELETE", 6) == 0) + req.method = HTTP_METHOD_DELETE; + else + req.method = HTTP_METHOD_GET; + + uint32_t j = i + 1; + uint32_t path_start = j; + while (j < (uint32_t)hdr_end && buf.data[j] != ' ') ++j; + req.path = string_repeat('\0', 0); + string_append_bytes(&req.path, buf.data + path_start, j - path_start); + + int status_line_end = strindex((char*)buf.data, "\r\n"); + http_header_parser( + (char*)buf.data + status_line_end + 2, + buf.length - (uint32_t)(status_line_end + 2), + &req.headers_common, + &req.extra_headers, + &req.extra_header_count + ); + + uint32_t body_start = hdr_end + 4; + uint32_t have = buf.length > body_start ? buf.length - body_start : 0; + uint32_t need = req.headers_common.length; + + if (need > 0) { + while (have < need) { + int64_t r = client->recv(tmp, sizeof(tmp)); + if (r <= 0) break; + string_append_bytes(&buf, tmp, (uint32_t)r); + have += (uint32_t)r; + } + } else { + int idle = 0; + while (idle < 5) { + int64_t r = client->recv(tmp, sizeof(tmp)); + if (r > 0) { + string_append_bytes(&buf, tmp, (uint32_t)r); + have += (uint32_t)r; + idle = 0; + } else { + ++idle; + sleep(20); + } + } + } + + if (have > 0) { + char* body_copy = (char*)malloc(have + 1); + if (body_copy) { + memcpy(body_copy, buf.data + body_start, have); + body_copy[have] = '\0'; + req.body.ptr = (uintptr_t)body_copy; + req.body.size = have; + } + } + + free(buf.data, buf.mem_length); + return req; + } + + + int32_t send_response(TCPSocket* client, const HTTPResponseMsg& res) { + if (!client) return SOCK_ERR_STATE; + string out = http_response_builder(&res); + int64_t sent = client->send(out.data, out.length); + free(out.data, out.mem_length); + return sent < 0 ? (int32_t)sent : SOCK_OK; + } + + int32_t close() { return sock.close(); } +}; diff --git a/shared/net/arp.c b/shared/net/arp.c deleted file mode 100644 index 403f226b..00000000 --- a/shared/net/arp.c +++ /dev/null @@ -1,28 +0,0 @@ -#include "arp.h" -#include "eth.h" -#include "std/memfunctions.h" - -void create_arp_packet(uintptr_t p, uint8_t* src_mac, uint32_t src_ip, uint8_t* dst_mac, uint32_t dst_ip, bool is_request){ - p = create_eth_packet(p, src_mac, is_request ? (uint8_t[]){0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF} : dst_mac, 0x806); - - arp_hdr_t* arp = (arp_hdr_t*)p; - - arp->htype = __builtin_bswap16(1); - arp->ptype = __builtin_bswap16(0x0800); - arp->hlen = 6; - arp->plen = 4; - arp->opcode = __builtin_bswap16(is_request ? 1 : 2); - memcpy(arp->sender_mac, src_mac, 6); - arp->sender_ip = __builtin_bswap32(src_ip); - memcpy(arp->target_mac, dst_mac, 6); - arp->target_ip = __builtin_bswap32(dst_ip); -} - -void arp_populate_response(network_connection_ctx *ctx, arp_hdr_t* arp){ - memcpy(ctx->mac, arp->sender_mac, 6); - ctx->ip = arp->sender_ip; -} - -bool arp_should_handle(arp_hdr_t *arp, uint32_t ip){ - return __builtin_bswap32(arp->target_ip) == ip; -} \ No newline at end of file diff --git a/shared/net/arp.h b/shared/net/arp.h deleted file mode 100644 index d29e1144..00000000 --- a/shared/net/arp.h +++ /dev/null @@ -1,28 +0,0 @@ -#pragma once - -#ifdef __cplusplus -extern "C" { -#endif - -#include "types.h" -#include "net/network_types.h" - -typedef struct __attribute__((packed)) arp_hdr_t { - uint16_t htype; - uint16_t ptype; - uint8_t hlen; - uint8_t plen; - uint16_t opcode; - uint8_t sender_mac[6]; - uint32_t sender_ip; - uint8_t target_mac[6]; - uint32_t target_ip; -} arp_hdr_t; - -void create_arp_packet(uintptr_t p, uint8_t* src_mac, uint32_t src_ip, uint8_t* dst_mac, uint32_t dst_ip, bool is_request); -bool arp_should_handle(arp_hdr_t *arp, uint32_t ip); -void arp_populate_response(network_connection_ctx *ctx, arp_hdr_t* arp); - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/shared/net/checksums.c b/shared/net/checksums.c index 32b11ee0..6eb0dfa0 100644 --- a/shared/net/checksums.c +++ b/shared/net/checksums.c @@ -1,25 +1,24 @@ -#include "network_types.h" +#include "checksums.h" -uint16_t checksum16(uint16_t *data, size_t len) { +uint16_t checksum16(const uint16_t *data, size_t len) { uint32_t sum = 0; - for (int i = 0; i < len; i++) sum += data[i]; + for (size_t i = 0; i < len; i++) sum += data[i]; while (sum >> 16) sum = (sum & 0xFFFF) + (sum >> 16); - return ~sum; + return (uint16_t)~sum; } -uint16_t checksum16_pipv4( - uint32_t src_ip, - uint32_t dst_ip, - uint8_t protocol, - const uint8_t* payload, - uint16_t length -) { +uint16_t checksum16_pipv4(uint32_t src_ip, + uint32_t dst_ip, + uint8_t protocol, + const uint8_t *payload, + uint16_t length) +{ uint32_t sum = 0; sum += (src_ip >> 16) & 0xFFFF; - sum += src_ip & 0xFFFF; + sum += src_ip & 0xFFFF; sum += (dst_ip >> 16) & 0xFFFF; - sum += dst_ip & 0xFFFF; + sum += dst_ip & 0xFFFF; sum += protocol; sum += length; @@ -32,5 +31,5 @@ uint16_t checksum16_pipv4( while (sum >> 16) sum = (sum & 0xFFFF) + (sum >> 16); - return ~sum; -} \ No newline at end of file + return (uint16_t)~sum; +} diff --git a/shared/net/checksums.h b/shared/net/checksums.h new file mode 100644 index 00000000..0e0b777e --- /dev/null +++ b/shared/net/checksums.h @@ -0,0 +1,17 @@ +#pragma once +#include "types.h" + +#ifdef __cplusplus +extern "C" { +#endif +uint16_t checksum16(const uint16_t *data, size_t len); + +uint16_t checksum16_pipv4(uint32_t src_ip, + uint32_t dst_ip, + uint8_t protocol, + const uint8_t *payload, + uint16_t length); + +#ifdef __cplusplus +} +#endif diff --git a/shared/net/dhcp.c b/shared/net/dhcp.c deleted file mode 100644 index cf21c0a0..00000000 --- a/shared/net/dhcp.c +++ /dev/null @@ -1,72 +0,0 @@ -#include "dhcp.h" -#include "std/memfunctions.h" -#include "math/rng.h" - -void create_dhcp_packet(uintptr_t p, dhcp_request *payload){ - network_connection_ctx source = (network_connection_ctx){ - .port = 68, - }; - network_connection_ctx destination = (network_connection_ctx){ - .ip = (255 << 24) | (255 << 16) | (255 << 8) | 255, - .mac = {0xFF, 0xFF, 0xFF, 0xFF, 0xFF, 0xFF}, - .port = 67, - }; - //TODO: use a syscall for the rng - rng_t rng; - rng_init_random(&rng); - dhcp_packet packet = (dhcp_packet){ - .op = 1,//request - .htype = 1,//Ethernet - .hlen = 6,//Mac length - .hops = 0, - .xid = rng_next32(&rng),//Transaction ID - .secs = 0, - .flags = __builtin_bswap16(0x8000),//Broadcast - .ciaddr = 0, - .yiaddr = 0, - .siaddr = 0, - .giaddr = 0, - }; - memcpy(packet.chaddr, payload->mac, 6); - memcpy(source.mac, payload->mac, 6); - - packet.options[0] = 0x63; // magic - packet.options[1] = 0x82; - packet.options[2] = 0x53; - packet.options[3] = 0x63; // magic - - packet.options[4] = 53; // DHCP type - packet.options[5] = 1; // length - if (payload->server_ip != 0 && payload->offered_ip != 0){ - packet.options[6] = 3; // DHCPREQUEST - - packet.options[7] = 50; - packet.options[8] = 4; - memcpy(&packet.options[9], &payload->offered_ip, 4); - - packet.options[13] = 54; - packet.options[14] = 4; - memcpy(&packet.options[15], &payload->server_ip, 4); - packet.options[19] = 255; - } else { - packet.options[6] = 1; // DHCPDISCOVER - - packet.options[7] = 255; // END - } - - create_udp_packet(p, source, destination, (sizedptr){(uintptr_t)&packet, sizeof(dhcp_packet)}); -} - -dhcp_packet* dhcp_parse_packet_payload(uintptr_t ptr){ - sizedptr sptr = udp_parse_packet_payload(ptr); - return (dhcp_packet*)sptr.ptr; -} - -uint16_t dhcp_parse_option(dhcp_packet *pack, uint16_t option){ - for (int i = 0; i < 312; i++) - if (pack->options[i] == option) return i; - - return 0; -} - - diff --git a/shared/net/dhcp.h b/shared/net/dhcp.h deleted file mode 100644 index 9035e5b8..00000000 --- a/shared/net/dhcp.h +++ /dev/null @@ -1,43 +0,0 @@ -#pragma once - -#ifdef __cplusplus -extern "C" { -#endif - -#include "types.h" -#include "net/network_types.h" -#include "net/udp.h" - -typedef struct __attribute__((packed)) dhcp_packet { - uint8_t op; - uint8_t htype; - uint8_t hlen; - uint8_t hops; - uint32_t xid; - uint16_t secs; - uint16_t flags; - uint32_t ciaddr; - uint32_t yiaddr; - uint32_t siaddr; - uint32_t giaddr; - uint8_t chaddr[16]; - uint8_t sname[64]; - uint8_t file[128]; - uint8_t options[312]; -} dhcp_packet; - -typedef struct dhcp_request { - uint8_t mac[6]; - uint32_t server_ip; - uint32_t offered_ip; -} dhcp_request; - -#define DHCP_SIZE sizeof(eth_hdr_t) + sizeof(ipv4_hdr_t) + sizeof(udp_hdr_t) + sizeof(dhcp_packet) - -void create_dhcp_packet(uintptr_t p, dhcp_request *data); -dhcp_packet* dhcp_parse_packet_payload(uintptr_t ptr); -uint16_t dhcp_parse_option(dhcp_packet *pack, uint16_t option); - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/shared/net/eth.c b/shared/net/eth.c deleted file mode 100644 index e1933d75..00000000 --- a/shared/net/eth.c +++ /dev/null @@ -1,23 +0,0 @@ -#include "eth.h" -#include "std/memfunctions.h" - -uintptr_t create_eth_packet(uintptr_t p, uint8_t src_mac[6], uint8_t dst_mac[6], uint16_t type){ - eth_hdr_t* eth = (eth_hdr_t*)p; - memcpy(eth->src_mac, src_mac, 6); - memcpy(eth->dst_mac, dst_mac, 6); - eth->ethertype = __builtin_bswap16(type); - return p + sizeof(eth_hdr_t); -} - -uint16_t eth_parse_packet_type(uintptr_t ptr){ - eth_hdr_t* eth = (eth_hdr_t*)ptr; - - ptr += sizeof(eth_hdr_t); - - return __builtin_bswap16(eth->ethertype); -} - -uintptr_t eth_get_source(uintptr_t ptr){ - eth_hdr_t* eth = (eth_hdr_t*)ptr; - return (uintptr_t)ð->src_mac; -} \ No newline at end of file diff --git a/shared/net/eth.h b/shared/net/eth.h deleted file mode 100644 index d5333391..00000000 --- a/shared/net/eth.h +++ /dev/null @@ -1,22 +0,0 @@ -#pragma once - -#ifdef __cplusplus -extern "C" { -#endif - -#include "types.h" -#include "net/network_types.h" - -typedef struct __attribute__((packed)) eth_hdr_t { - uint8_t dst_mac[6]; - uint8_t src_mac[6]; - uint16_t ethertype; -} eth_hdr_t; - -uint16_t eth_parse_packet_type(uintptr_t ptr); -uintptr_t create_eth_packet(uintptr_t ptr, uint8_t src_mac[6], uint8_t dst_mac[6], uint16_t type); -uintptr_t eth_get_source(uintptr_t ptr); - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/shared/net/http.c b/shared/net/http.c deleted file mode 100644 index cb27520f..00000000 --- a/shared/net/http.c +++ /dev/null @@ -1,114 +0,0 @@ -#include "http.h" -#include "std/string.h" -#include "tcp.h" -#include "syscalls/syscalls.h" -#include "ipv4.h" -#include "std/memfunctions.h" - -string make_http_request(HTTPRequest request, char *domain, char *agent){ - //TODO: request instead of hardcoded GET - return string_format("GET / HTTP/1.1\r\nHost: %s\r\nUser-Agent: %s\r\nAccept: */*\r\n\r\n",domain, agent); -} - -sizedptr http_data_transfer(network_connection_ctx *dest, sizedptr payload, uint16_t port, tcp_data *data, uint8_t retry, uint32_t orig_seq, uint32_t orig_ack){ - if (retry == 5){ - printf("Exceeded max number of retries"); - return (sizedptr){0}; - } - - data->sequence = orig_seq; - data->ack = orig_ack; - data->flags = (1 << PSH_F) | (1 << ACK_F); - - data->payload = payload; - - tcp_send(port, dest, data); - - data->flags = (1 << ACK_F); - - uint8_t resp; - do { - resp = tcp_check_response(data, 0); - if (resp == TCP_OK) - break; - if (resp == TCP_RESET)//We don't reset, we ignore irrelevant packets (or we could parse them tbh) - continue; - if (resp == TCP_RETRY) - return http_data_transfer(dest, payload, port, data, retry+1, orig_seq, orig_ack); - } while (1); - - data->flags = (1 << PSH_F) | (1 << ACK_F); - - sizedptr http_content; - - resp = tcp_check_response(data, &http_content); - if (resp == TCP_RETRY){ - sleep(1000); - return http_data_transfer(dest, payload, port, data, retry+1, orig_seq, orig_ack); - } else if (resp == TCP_RESET){ - tcp_reset(port, dest, data); - return (sizedptr){0}; - } - - data->payload = (sizedptr){0}; - - data->flags = (1 << ACK_F); - tcp_send(port, dest, data); - - return http_content; -} - -sizedptr request_http_data(HTTPRequest request, network_connection_ctx *dest, uint16_t port){ - tcp_data data = (tcp_data){ - .window = UINT16_MAX, - }; - - printf("TCP Handshake"); - - if (!tcp_handskake(dest, 8888, &data, 0)){ - printf("TCP Handshake Error"); - return (sizedptr){0}; - } - - string serverstr = ipv4_to_string(dest->ip); - string req = make_http_request(request, serverstr.data, "redactedos/0.0.1"); - - free(serverstr.data, serverstr.mem_length); - - printf("HTTP Request"); - - //TODO: more chunked support - - sizedptr http_response = http_data_transfer(dest, (sizedptr){(uintptr_t)req.data, req.length}, port, &data, 0, data.sequence, data.ack); - - printf("TCP End"); - - free(req.data, req.mem_length); - - if (!tcp_close(dest, 8888, &data, 0, data.sequence, data.ack)){ - printf("TCP Connnection not closed"); - return (sizedptr){0}; - } - - return http_response; -} - -sizedptr http_get_payload(sizedptr header){ - if (header.ptr && header.size > 0){ - int start = strindex((char*)header.ptr, "\r\n\r\n"); - if (start < header.size){ - return (sizedptr){header.ptr + start + 4,header.size-start-4}; - } - } - return (sizedptr){0,0}; -} - -string http_get_chunked_payload(sizedptr chunk){ - //TODO: allow finding 0 to know when we're done reading the payload - if (chunk.ptr && chunk.size > 0){ - int sizetrm = strindex((char*)chunk.ptr, "\r\n"); - uint64_t chunk_size = parse_hex_u64((char*)chunk.ptr,sizetrm); - return string_ca_max((char*)(chunk.ptr + sizetrm + 2),chunk_size); - } - return (string){0}; -} \ No newline at end of file diff --git a/shared/net/http.h b/shared/net/http.h deleted file mode 100644 index fd94bdab..00000000 --- a/shared/net/http.h +++ /dev/null @@ -1,17 +0,0 @@ -#pragma once - -#include "tcp.h" -#include "types.h" -#include "network_types.h" -#include "std/string.h" - -typedef enum HTTPRequest { - GET, - POST, - PUT, - DELETE -} HTTPRequest; - -sizedptr request_http_data(HTTPRequest request, network_connection_ctx *dest, uint16_t port); -sizedptr http_get_payload(sizedptr header); -string http_get_chunked_payload(sizedptr chunk); \ No newline at end of file diff --git a/shared/net/icmp.c b/shared/net/icmp.c deleted file mode 100644 index cd79dc87..00000000 --- a/shared/net/icmp.c +++ /dev/null @@ -1,31 +0,0 @@ -#include "icmp.h" -#include "net/udp.h" -#include "net/eth.h" -#include "net/ipv4.h" -#include "std/memfunctions.h" - -void create_icmp_packet(uintptr_t p, network_connection_ctx source, network_connection_ctx destination, icmp_data* data){ - p = create_eth_packet(p, source.mac, destination.mac, 0x800); - - p = create_ipv4_packet(p, sizeof(icmp_packet), 0x01, source.ip, destination.ip); - - icmp_packet *packet = (icmp_packet*)p; - - packet->type = __builtin_bswap16(data->response ? 0 : 8); - packet->seq = __builtin_bswap16(data->seq); - packet->id = __builtin_bswap16(data->id); - memcpy(packet->payload, data->payload, 56); - packet->checksum = checksum16((uint16_t*)packet, sizeof(icmp_packet)); -} - -uint16_t icmp_get_sequence(icmp_packet *packet){ - return __builtin_bswap16(packet->seq); -} - -uint16_t icmp_get_id(icmp_packet *packet){ - return __builtin_bswap16(packet->id); -} - -void icmp_copy_payload(void* dest, icmp_packet *packet){ - memcpy(dest, packet->payload, 56); -} \ No newline at end of file diff --git a/shared/net/icmp.h b/shared/net/icmp.h deleted file mode 100644 index 59dd4410..00000000 --- a/shared/net/icmp.h +++ /dev/null @@ -1,33 +0,0 @@ -#pragma once - -#ifdef __cplusplus -extern "C" { -#endif - -#include "types.h" -#include "net/network_types.h" - -typedef struct __attribute__((packed)) icmp_packet { - uint8_t type; - uint8_t code; - uint16_t checksum; - uint16_t id; - uint16_t seq; - uint8_t payload[56]; -} icmp_packet; - -typedef struct icmp_data { - uint8_t response; - uint16_t seq; - uint16_t id; - uint8_t payload[56]; -} icmp_data; - -void create_icmp_packet(uintptr_t p, network_connection_ctx source, network_connection_ctx destination, icmp_data *data); -uint16_t icmp_get_sequence(icmp_packet *packet); -uint16_t icmp_get_id(icmp_packet *packet); -void icmp_copy_payload(void* dest, icmp_packet *packet); - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/shared/net/internet_layer/icmp.c b/shared/net/internet_layer/icmp.c new file mode 100644 index 00000000..11aa79d9 --- /dev/null +++ b/shared/net/internet_layer/icmp.c @@ -0,0 +1,151 @@ +#include "icmp.h" +#include "net/internet_layer/ipv4.h" +#include "net/network_types.h" +#include "net/checksums.h" +#include "std/memfunctions.h" +#include "console/kio.h" +#include "ipv4.h" +#include "networking/network.h" + +extern uintptr_t malloc(uint64_t size); +extern void free(void *ptr, uint64_t size); +extern void sleep(uint64_t ms); + +#define MAX_PENDING 16 +#define POLL_MS 1 + +typedef struct { + bool in_use; + uint16_t id, seq; + bool received; +} ping_slot_t; + +static ping_slot_t g_pending[MAX_PENDING] = {0}; + +static int alloc_slot(uint16_t id, uint16_t seq){ + for(int i=0;i=0&&itype = d->response ? ICMP_ECHO_REPLY : ICMP_ECHO_REQUEST; + pkt->code = 0; + pkt->id = __builtin_bswap16(d->id); + pkt->seq = __builtin_bswap16(d->seq); + + if (d->payload) + memcpy(pkt->payload, d->payload, 56); + else + memset(pkt->payload, 0, 56); + + pkt->checksum = 0; +} + +void icmp_send_echo(uint32_t dst_ip, + uint16_t id, + uint16_t seq, + const uint8_t payload[56]) +{ + uint32_t pay_len = payload ? 56 : 32; + icmp_data d = { .response=false, .id=id, .seq=seq }; + if (payload) memcpy(d.payload, payload, 56); + else memset(d.payload, 0, 56); + + uint32_t icmp_len = 8 + pay_len; + uintptr_t buf = (uintptr_t)malloc(icmp_len); + if(!buf) return; + + const net_l2l3_endpoint *local = network_get_local_endpoint(); + create_icmp_packet(buf, local, NULL, &d); + + ((icmp_packet*)buf)->checksum = checksum16((uint16_t*)buf, icmp_len); + + ipv4_send_segment(local->ip, dst_ip, 1, (sizedptr){ buf, icmp_len }); + + free((void*)buf, icmp_len); +} + +void icmp_input(uintptr_t ptr, + uint32_t len, + uint32_t src_ip, + uint32_t dst_ip) +{ + if(len < 8) return; + + icmp_packet *pkt = (icmp_packet*)ptr; + uint16_t recv_ck = pkt->checksum; + pkt->checksum = 0; + if(checksum16((uint16_t*)pkt, len) != recv_ck) return; + pkt->checksum = recv_ck; + + uint8_t type = pkt->type; + uint16_t id = __builtin_bswap16(pkt->id); + uint16_t sq = __builtin_bswap16(pkt->seq); + uint32_t pay = len - 8; + if(pay > 56) pay = 56; + + if(type == ICMP_ECHO_REQUEST){ + icmp_data d = { .response=true, .id=id, .seq=sq }; + memcpy(d.payload, pkt->payload, pay); + memset(d.payload + pay, 0, 56 - pay); + + uint32_t reply_len = 8 + pay; + uintptr_t buf = (uintptr_t)malloc(reply_len); + if(!buf) return; + + const net_l2l3_endpoint *local = network_get_local_endpoint(); + create_icmp_packet(buf, local, NULL, &d); + ((icmp_packet*)buf)->checksum = checksum16((uint16_t*)buf, reply_len); + + ipv4_send_segment(local->ip, src_ip, 1, (sizedptr){ buf, reply_len }); + free((void*)buf, reply_len); + return; + } + + if(type == ICMP_ECHO_REPLY) + mark_received(id, sq); +} + +bool icmp_ping(uint32_t dst_ip, + uint16_t id, + uint16_t seq, + uint32_t timeout_ms) +{ + int slot = alloc_slot(id, seq); + if(slot < 0) return false; + + icmp_send_echo(dst_ip, id, seq, NULL); + + uint32_t waited = 0; + while(waited < timeout_ms){ + if(g_pending[slot].received){ + free_slot(slot); + return true; + } + sleep(POLL_MS); + waited += POLL_MS; + } + + free_slot(slot); + return false; +} diff --git a/shared/net/internet_layer/icmp.h b/shared/net/internet_layer/icmp.h new file mode 100644 index 00000000..387eb477 --- /dev/null +++ b/shared/net/internet_layer/icmp.h @@ -0,0 +1,50 @@ +#pragma once +#include "types.h" +#include "net/network_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +#define ICMP_ECHO_REPLY 0 +#define ICMP_ECHO_REQUEST 8 + +typedef struct __attribute__((packed)) { + uint8_t type; + uint8_t code; + uint16_t checksum; + uint16_t id; + uint16_t seq; + uint8_t payload[56]; +} icmp_packet; + +typedef struct { + bool response; //1 replay 0 request + uint16_t id; + uint16_t seq; + uint8_t payload[56]; +} icmp_data; + +void create_icmp_packet(uintptr_t p, + const net_l2l3_endpoint *src, + const net_l2l3_endpoint *dst, + const icmp_data *data); + +void icmp_input(uintptr_t ptr, + uint32_t len, + uint32_t src_ip, + uint32_t dst_ip); + +void icmp_send_echo(uint32_t dst_ip, + uint16_t id, + uint16_t seq, + const uint8_t payload[56]); + +bool icmp_ping(uint32_t dst_ip, + uint16_t id, + uint16_t seq, + uint32_t timeout_ms); + +#ifdef __cplusplus +} +#endif diff --git a/shared/net/internet_layer/ipv4.c b/shared/net/internet_layer/ipv4.c new file mode 100644 index 00000000..6ee6545d --- /dev/null +++ b/shared/net/internet_layer/ipv4.c @@ -0,0 +1,181 @@ +#include "ipv4.h" +#include "console/kio.h" +#include "std/memfunctions.h" +#include "networking/network.h" +#include "net/link_layer/arp.h" +#include "net/transport_layer/udp.h" +#include "net/transport_layer/tcp.h" +#include "icmp.h" +#include "std/string.h" +#include "types.h" +#include "ipv4_route.h" + +extern uintptr_t malloc(uint64_t size); +extern void free(void *ptr, uint64_t size); +extern void sleep(uint64_t ms); + +static net_runtime_opts_t g_rt_opts; +net_cfg_t g_net_cfg = { + .ip = 0, + .mask = 0, + .gw = 0, + .mode = NET_MODE_DHCP, + .rt = &g_rt_opts +}; + +void ipv4_cfg_init() { + memset(&g_rt_opts, 0, sizeof(g_rt_opts)); + g_net_cfg.ip = 0; + g_net_cfg.mask = 0; + g_net_cfg.gw = 0; + g_net_cfg.mode = NET_MODE_DHCP; + g_net_cfg.rt = &g_rt_opts; + ipv4_rt_init(); +} + +void ipv4_set_cfg(const net_cfg_t *src) { + if (!src) return; + g_net_cfg.ip = src->ip; + g_net_cfg.mask = src->mask; + g_net_cfg.gw = src->gw; + g_net_cfg.mode = src->mode; + if (src->rt) { + g_rt_opts = *src->rt; + } else { + memset(&g_rt_opts, 0, sizeof(g_rt_opts)); + } + if(g_net_cfg.ip != 0){ + uint8_t bmac[6] = {0xFF,0xFF,0xFF,0xFF,0xFF,0xFF}; + arp_table_put(ipv4_broadcast(g_net_cfg.ip, g_net_cfg.mask), bmac, 0, true); + } + g_net_cfg.rt = &g_rt_opts; + ipv4_rt_init(); + if (g_net_cfg.gw) { + ipv4_rt_add(0, 0, g_net_cfg.gw); + } +} + +const net_cfg_t* ipv4_get_cfg() { + return &g_net_cfg; +} + +string ipv4_to_string(uint32_t ip) { + return string_format("%i.%i.%i.%i", + (ip>>24)&0xFF, + (ip>>16)&0xFF, + (ip>>8)&0xFF, + ip&0xFF); +} + +static uint16_t ipv4_checksum(const void *buf, size_t len) { + const uint16_t *data = buf; + uint32_t sum = 0; + for (; len > 1; len -= 2) { + sum += *data++; + } + if (len) { + sum += *(const uint8_t*)data; + } + while (sum >> 16) { + sum = (sum & 0xFFFF) + (sum >> 16); + } + return (uint16_t)~sum; +} + +void ip_input(uintptr_t ip_ptr, + uint32_t ip_len, + const uint8_t src_mac[6]) +{ + if (ip_len < sizeof(ipv4_hdr_t)) return; + ipv4_hdr_t *hdr = (ipv4_hdr_t*)ip_ptr; + uint8_t version = hdr->version_ihl >> 4; + uint8_t ihl = hdr->version_ihl & 0x0F; + if (version != 4 || ihl < 5) return; + + uint32_t header_bytes = ihl * 4; + if (ip_len < header_bytes) return; + + if (hdr->header_checksum != 0) { + uint16_t recv_ck = hdr->header_checksum; + hdr->header_checksum = 0; + if (ipv4_checksum(hdr, header_bytes) != recv_ck) return; + hdr->header_checksum = recv_ck; + } + + uint32_t sip = __builtin_bswap32(hdr->src_ip); + arp_table_put(sip, src_mac, 60000, false); + + uint32_t dip = __builtin_bswap32(hdr->dst_ip); + const net_cfg_t *cfg = ipv4_get_cfg(); //TODO manage special ip + + uintptr_t payload_ptr = ip_ptr + header_bytes; + uint32_t payload_len = __builtin_bswap16(hdr->total_length) - header_bytes; + switch (hdr->protocol) { + case 1://icmp + icmp_input(payload_ptr, payload_len, sip, dip); + break; + case 6://tcp + tcp_input(payload_ptr, payload_len, sip, dip); + break; + case 17://udp + udp_input(payload_ptr, payload_len, sip, dip); + break; + default: + //everything elese + break; + } +} + +void ipv4_send_segment(uint32_t src_ip, + uint32_t dst_ip, + uint8_t proto, + sizedptr segment) +{ + uint32_t nh_ip; + if (!ipv4_rt_lookup(dst_ip, &nh_ip)) { + const net_cfg_t *cfg = ipv4_get_cfg(); + if (cfg && ((dst_ip & cfg->mask) == (cfg->ip & cfg->mask))) { + nh_ip = dst_ip; + } else { + nh_ip = cfg ? cfg->gw : dst_ip; + } + } + + uint8_t dst_mac[6]; + bool ok = arp_resolve(nh_ip, dst_mac, 200); + if (!ok) { + memset(dst_mac, 0xFF, sizeof(dst_mac)); + } + + uint32_t total = sizeof(eth_hdr_t) + + sizeof(ipv4_hdr_t) + + segment.size; + uintptr_t buf = (uintptr_t)malloc(total); + if (!buf) return; + + const net_l2l3_endpoint *local = network_get_local_endpoint(); + uintptr_t ptr = create_eth_packet(buf, local->mac, dst_mac, 0x0800); + + ipv4_hdr_t *ip = (ipv4_hdr_t *)ptr; + ip->version_ihl = (4 << 4) | (sizeof(*ip)/4); + ip->dscp_ecn = 0; + ip->total_length = __builtin_bswap16(sizeof(*ip) + segment.size); + ip->identification = 0; + ip->flags_frag_offset = __builtin_bswap16(0x4000); + ip->ttl = 64; + ip->protocol = proto; + ip->src_ip = __builtin_bswap32(src_ip); + ip->dst_ip = __builtin_bswap32(dst_ip); + ip->header_checksum = 0; + ip->header_checksum = ipv4_checksum(ip, sizeof(*ip)); + + ptr += sizeof(*ip); + + if (segment.size) { + memcpy((void*)ptr, (void*)segment.ptr, segment.size); + } + + eth_send_frame(buf, total); + + free((void*)buf, total); +} diff --git a/shared/net/internet_layer/ipv4.h b/shared/net/internet_layer/ipv4.h new file mode 100644 index 00000000..40fefe2d --- /dev/null +++ b/shared/net/internet_layer/ipv4.h @@ -0,0 +1,71 @@ +#pragma once +#include "types.h" +#include "std/string.h" +#include "net/link_layer/eth.h" +#include "net/network_types.h" +#include "net/checksums.h" +#ifdef __cplusplus +extern "C" { +#endif +#define NET_MODE_DISABLED ((int8_t)-1) +#define NET_MODE_DHCP 0 +#define NET_MODE_STATIC 1 + +typedef struct net_runtime_opts { + uint16_t mtu; + uint32_t t1; + uint32_t t2; + uint32_t dns[2]; + uint32_t ntp[2]; + uint16_t xid; + uint32_t server_ip; + uint32_t lease; +} net_runtime_opts_t; + +typedef struct net_cfg { + uint32_t ip; + uint32_t mask; + uint32_t gw; + int8_t mode; + net_runtime_opts_t *rt; +} net_cfg_t; + +typedef struct __attribute__((packed)) ipv4_hdr_t { + uint8_t version_ihl; + uint8_t dscp_ecn; + uint16_t total_length; + uint16_t identification; + uint16_t flags_frag_offset; + uint8_t ttl; + uint8_t protocol; + uint16_t header_checksum; + uint32_t src_ip; + uint32_t dst_ip; +} ipv4_hdr_t; + +void ipv4_cfg_init(); +void ipv4_set_cfg(const net_cfg_t *src); +const net_cfg_t* ipv4_get_cfg(); + +string ipv4_to_string(uint32_t ip); + +void ipv4_send_segment(uint32_t src_ip, + uint32_t dst_ip, + uint8_t proto, + sizedptr segment); + +void ip_input(uintptr_t ip_ptr, + uint32_t ip_len, + const uint8_t src_mac[6]); + +static inline uint32_t ipv4_network(uint32_t ip, uint32_t mask){ return ip & mask; } +static inline uint32_t ipv4_broadcast(uint32_t ip, uint32_t mask){ return (ip & mask) | ~mask; } +static inline uint32_t ipv4_first_host(uint32_t ip, uint32_t mask){ return (ip & mask) + 1; } +static inline uint32_t ipv4_last_host(uint32_t ip, uint32_t mask){ return ((ip & mask) | ~mask) - 1; } + +void ipv4_cfg_init(); +void ipv4_set_cfg(const net_cfg_t *src); +const net_cfg_t* ipv4_get_cfg(); +#ifdef __cplusplus +} +#endif diff --git a/shared/net/internet_layer/ipv4_route.c b/shared/net/internet_layer/ipv4_route.c new file mode 100644 index 00000000..b5729e5a --- /dev/null +++ b/shared/net/internet_layer/ipv4_route.c @@ -0,0 +1,66 @@ +#include "ipv4_route.h" +#include "std/memfunctions.h" + +static ipv4_rt_entry_t g_rt[IPV4_RT_MAX]; +static int g_rt_len = 0; + +void ipv4_rt_init() { + g_rt_len = 0; + memset(g_rt, 0, sizeof(g_rt)); +} + +bool ipv4_rt_add(uint32_t network, uint32_t mask, uint32_t gateway) +{ + if (g_rt_len >= IPV4_RT_MAX) return false; + + for (int i = 0; i < g_rt_len; ++i) { + if (g_rt[i].network == network && g_rt[i].mask == mask) { + g_rt[i].gateway = gateway; + return true; + } + } + g_rt[g_rt_len++] = (ipv4_rt_entry_t){ network, mask, gateway }; + return true; +} + +bool ipv4_rt_del(uint32_t network, uint32_t mask) +{ + for (int i = 0; i < g_rt_len; ++i) { + if (g_rt[i].network == network && g_rt[i].mask == mask) { + g_rt[i] = g_rt[--g_rt_len]; + memset(&g_rt[g_rt_len], 0, sizeof(g_rt[0])); + return true; + } + } + return false; +} + +static inline int prefix_len(uint32_t mask) +{ + int len = 0; + while (mask & 0x80000000U) { ++len; mask <<= 1; } + return len; +} + +bool ipv4_rt_lookup(uint32_t dst, uint32_t *next_hop) +{ + int best_len = -1; + uint32_t best_nh = 0; + + for (int i = 0; i < g_rt_len; ++i) { + uint32_t net = g_rt[i].network; + uint32_t mask = g_rt[i].mask; + if (mask && (dst & mask) == net) { + int l = prefix_len(mask); + if (l > best_len) { + best_len = l; + best_nh = g_rt[i].gateway ? g_rt[i].gateway : dst; + } + } + } + if (best_len >= 0) { + if (next_hop) *next_hop = best_nh; + return true; + } + return false; +} diff --git a/shared/net/internet_layer/ipv4_route.h b/shared/net/internet_layer/ipv4_route.h new file mode 100644 index 00000000..e030427b --- /dev/null +++ b/shared/net/internet_layer/ipv4_route.h @@ -0,0 +1,15 @@ +#pragma once +#include "types.h" + +#define IPV4_RT_MAX 8 + +typedef struct { + uint32_t network; + uint32_t mask; + uint32_t gateway; +} ipv4_rt_entry_t; + +void ipv4_rt_init(); +bool ipv4_rt_add(uint32_t network, uint32_t mask, uint32_t gateway); +bool ipv4_rt_del(uint32_t network, uint32_t mask); +bool ipv4_rt_lookup(uint32_t dst, uint32_t *next_hop); diff --git a/shared/net/ipv4.c b/shared/net/ipv4.c deleted file mode 100644 index 638b20c0..00000000 --- a/shared/net/ipv4.c +++ /dev/null @@ -1,37 +0,0 @@ -#include "ipv4.h" -#include "console/kio.h" -#include "network_types.h" -#include "std/string.h" -#include "std/memfunctions.h" - -uintptr_t create_ipv4_packet(uintptr_t p, uint32_t payload_len, uint8_t protocol, uint32_t source_ip, uint32_t destination_ip){ - ipv4_hdr_t* ip = (ipv4_hdr_t*)p; - ip->version_ihl = 0x45; - ip->dscp_ecn = 0; - ip->total_length = __builtin_bswap16(sizeof(ipv4_hdr_t) + payload_len); - ip->identification = 0; - ip->flags_frag_offset = __builtin_bswap16(0x4000); - ip->ttl = 64; - ip->protocol = protocol; - ip->src_ip = __builtin_bswap32(source_ip); - ip->dst_ip = __builtin_bswap32(destination_ip); - ip->header_checksum = checksum16((uint16_t*)ip, 10); - return p + sizeof(ipv4_hdr_t); -} - -uint8_t ipv4_get_protocol(uintptr_t ptr){ - return ((ipv4_hdr_t*)ptr)->protocol; -} - -void ipv4_populate_response(network_connection_ctx *ctx, eth_hdr_t *eth, ipv4_hdr_t* ipv4){ - ctx->ip = __builtin_bswap32(ipv4->src_ip); - memcpy(ctx->mac, eth->src_mac, 6); -} - -string ipv4_to_string(uint32_t ip){ - return string_format("%i.%i.%i.%i",(ip >> 24) & 0xFF,(ip >> 16) & 0xFF,(ip >> 8) & 0xFF,(ip >> 0) & 0xFF); -} - -uint32_t ipv4_get_source(uintptr_t ptr){ - return __builtin_bswap32(((ipv4_hdr_t*)ptr)->src_ip); -} \ No newline at end of file diff --git a/shared/net/ipv4.h b/shared/net/ipv4.h deleted file mode 100644 index ff1d4182..00000000 --- a/shared/net/ipv4.h +++ /dev/null @@ -1,33 +0,0 @@ -#pragma once - -#ifdef __cplusplus -extern "C" { -#endif - -#include "types.h" -#include "net/network_types.h" -#include "std/string.h" -#include "eth.h" - -typedef struct __attribute__((packed)) ipv4_hdr_t { - uint8_t version_ihl; - uint8_t dscp_ecn; - uint16_t total_length; - uint16_t identification; - uint16_t flags_frag_offset; - uint8_t ttl; - uint8_t protocol; - uint16_t header_checksum; - uint32_t src_ip; - uint32_t dst_ip; -} ipv4_hdr_t; - -uint8_t ipv4_get_protocol(uintptr_t ptr); -uintptr_t create_ipv4_packet(uintptr_t p, uint32_t payload_len, uint8_t protocol, uint32_t source_ip, uint32_t destination_ip); -void ipv4_populate_response(network_connection_ctx *ctx, eth_hdr_t *eth, ipv4_hdr_t* ipv4); -string ipv4_to_string(uint32_t ip); -uint32_t ipv4_get_source(uintptr_t ptr); - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/shared/net/link_layer/arp.c b/shared/net/link_layer/arp.c new file mode 100644 index 00000000..18eac6db --- /dev/null +++ b/shared/net/link_layer/arp.c @@ -0,0 +1,216 @@ +#include "arp.h" +#include "eth.h" +#include "console/kio.h" +#include "std/memfunctions.h" +#include "net/internet_layer/ipv4.h" +#include "networking/network.h" +#include "process/scheduler.h" +#include "types.h" +#include "std/string.h" +#include "networking/network.h" + + +#define ARP_OPCODE_REQUEST 1 +#define ARP_OPCODE_REPLY 2 + +extern void sleep(uint64_t ms); +extern uintptr_t malloc(uint64_t size); +extern void free(void *ptr, uint64_t size); + +static uint16_t g_arp_pid = 0xFFFF; +static arp_entry_t g_arp_table[ARP_TABLE_MAX]; +static bool init = false; +#define KP(fmt, ...) \ + do { kprintf(fmt, ##__VA_ARGS__); } while (0) + +void arp_set_pid(uint16_t pid) { g_arp_pid = pid; } +uint16_t arp_get_pid() { return g_arp_pid; } + +void arp_table_init() { + KP("[ARP] init"); + memset(g_arp_table, 0, sizeof(g_arp_table)); + init = true; + arp_table_init_static_defaults(); +} + +void arp_table_init_static_defaults() { + uint8_t bmac[6] = {0xFF,0xFF,0xFF,0xFF,0xFF,0xFF}; + arp_table_put(0xFFFFFFFF, bmac, 0, true); +} + +static int arp_table_find_slot(uint32_t ip) { + for (int i = 0; i < ARP_TABLE_MAX; i++) { + if (g_arp_table[i].ip == ip) return i; + } + return -1; +} + +static int arp_table_find_free() { + for (int i = 0; i < ARP_TABLE_MAX; i++) { + if (g_arp_table[i].ip == 0) return i; + } + return -1; +} + +void arp_table_put(uint32_t ip, const uint8_t mac[6], uint32_t ttl_ms, bool is_static) { + int idx = arp_table_find_slot(ip); + if (idx < 0) idx = arp_table_find_free(); + if (idx < 0) idx = 0; + + g_arp_table[idx].ip = ip; + memcpy(g_arp_table[idx].mac, mac, 6); + g_arp_table[idx].ttl_ms = is_static ? 0 : ttl_ms; + g_arp_table[idx].static_entry = is_static ? 1 : 0; + + /*KP("[ARP] put %i.%i.%i.%i -> %x:%x:%x:%x:%x:%x static=%i ttl=%i\n", + (uint64_t)((ip>>24)&0xFF), (uint64_t)((ip>>16)&0xFF), + (uint64_t)((ip>>8)&0xFF), (uint64_t)(ip&0xFF), + (uint64_t)mac[0], (uint64_t)mac[1], (uint64_t)mac[2], + (uint64_t)mac[3], (uint64_t)mac[4], (uint64_t)mac[5], + (uint64_t)g_arp_table[idx].static_entry, (uint64_t)ttl_ms);*/ +} + +bool arp_table_get(uint32_t ip, uint8_t mac_out[6]) { + int idx = arp_table_find_slot(ip); + if (idx < 0) return false; + memcpy(mac_out, g_arp_table[idx].mac, 6); + return true; +} + +void arp_table_tick(uint32_t ms) { + for (int i = 0; i < ARP_TABLE_MAX; i++) { + if (g_arp_table[i].ip == 0 || g_arp_table[i].static_entry) + continue; + if (g_arp_table[i].ttl_ms <= ms) { + memset(&g_arp_table[i], 0, sizeof(arp_entry_t)); + } else { + g_arp_table[i].ttl_ms -= ms; + } + } +} + +bool arp_resolve(uint32_t ip, uint8_t mac_out[6], uint32_t timeout_ms) { + + if (arp_table_get(ip, mac_out)) return true; + if (ip == 0xFFFFFFFF) { + memset(mac_out, 0xFF, 6); + return true; + } + arp_send_request(ip); + + uint32_t waited = 0; + const uint32_t POLL_MS = 100; + while (waited < timeout_ms) { + arp_table_tick(POLL_MS); + if (arp_table_get(ip, mac_out)) return true; + sleep(POLL_MS); + waited += POLL_MS; + } + return false; +} + +void arp_send_request(uint32_t target_ip) { + const net_l2l3_endpoint *ep = network_get_local_endpoint(); + uint8_t dst_mac[6]; + arp_hdr_t hdr; + uintptr_t buf; + uint32_t len; + + memset(dst_mac, 0xFF, sizeof(dst_mac)); + memset(hdr.target_mac, 0x00, sizeof(hdr.target_mac)); + + hdr.htype = __builtin_bswap16(1); + hdr.ptype = __builtin_bswap16(0x0800); + hdr.hlen = 6; + hdr.plen = 4; + hdr.opcode = __builtin_bswap16(1); + memcpy(hdr.sender_mac, ep->mac, 6); + hdr.sender_ip = __builtin_bswap32(ep->ip); + hdr.target_ip = __builtin_bswap32(target_ip); + + len = sizeof(eth_hdr_t) + sizeof(arp_hdr_t); + buf = (uintptr_t)malloc(len); + if (!buf) return; + + uintptr_t ptr = create_eth_packet(buf, ep->mac, dst_mac, 0x0806); + memcpy((void*)ptr, &hdr, sizeof(arp_hdr_t)); + + eth_send_frame(buf, len); + free((void*)buf, len); +} + +bool arp_should_handle(const arp_hdr_t *arp, uint32_t my_ip) { + return __builtin_bswap32(arp->target_ip) == my_ip; +} + +void arp_populate_response(net_l2l3_endpoint *ep, const arp_hdr_t *arp) { + memcpy(ep->mac, arp->sender_mac, 6); + ep->ip = __builtin_bswap32(arp->sender_ip); +} + +bool arp_can_reply() { + const net_cfg_t *cfg = ipv4_get_cfg(); + return (cfg && cfg->ip != 0 && cfg->mode != NET_MODE_DISABLED); +} + +void arp_daemon_entry() { + while (1){ + const net_cfg_t *cfg = ipv4_get_cfg(); + if(cfg && cfg->ip != 0 && cfg->mode != NET_MODE_DISABLED) break; + sleep(200); + } + arp_table_init(); + + while (1) { + arp_table_tick(1000); + sleep(1000); + } +} +static void arp_send_reply(const arp_hdr_t *in_arp, + const uint8_t in_src_mac[6], + uint32_t frame_len) { + const net_l2l3_endpoint *ep = network_get_local_endpoint(); + + uint32_t len = sizeof(eth_hdr_t) + sizeof(arp_hdr_t); + uintptr_t buf = (uintptr_t)malloc(len); + if (!buf) return; + + uintptr_t ptr = create_eth_packet(buf, + ep->mac, + in_src_mac, + 0x0806); + + arp_hdr_t reply = *in_arp; + memcpy(reply.target_mac, in_arp->sender_mac, 6); + memcpy(reply.sender_mac, ep->mac, 6); + reply.target_ip = in_arp->sender_ip; + reply.sender_ip = __builtin_bswap32(ep->ip); + reply.opcode = __builtin_bswap16(ARP_OPCODE_REPLY); + + memcpy((void*)ptr, &reply, sizeof(reply)); + + eth_send_frame(buf, len); + free((void*)buf, len); +} + + +void arp_input(uintptr_t frame_ptr, uint32_t frame_len) { + if (frame_len < sizeof(eth_hdr_t) + sizeof(arp_hdr_t)) return; + + if(!init) return; + + arp_hdr_t *hdr = (arp_hdr_t*)(frame_ptr + sizeof(eth_hdr_t)); + uint32_t sender_ip = __builtin_bswap32(hdr->sender_ip); + + arp_table_put(sender_ip, hdr->sender_mac, 180000, false); + + const net_l2l3_endpoint *ep = network_get_local_endpoint(); + if (__builtin_bswap16(hdr->opcode) == ARP_OPCODE_REQUEST && + arp_should_handle(hdr, ep->ip) && + arp_can_reply()) + { + const arp_hdr_t *hdr = (arp_hdr_t*)(frame_ptr + sizeof(eth_hdr_t)); + const uint8_t *src_mac = hdr->sender_mac; + arp_send_reply(hdr, src_mac, frame_len); + } +} \ No newline at end of file diff --git a/shared/net/link_layer/arp.h b/shared/net/link_layer/arp.h new file mode 100644 index 00000000..4e5454a3 --- /dev/null +++ b/shared/net/link_layer/arp.h @@ -0,0 +1,54 @@ +#pragma once +#include "types.h" +#include "net/network_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct __attribute__((packed)) arp_hdr_t { + uint16_t htype; + uint16_t ptype; + uint8_t hlen; + uint8_t plen; + uint16_t opcode;//1=request, 2=reply + uint8_t sender_mac[6]; + uint32_t sender_ip; + uint8_t target_mac[6]; + uint32_t target_ip; +} arp_hdr_t; + +bool arp_should_handle(const arp_hdr_t *arp, uint32_t my_ip); +void arp_populate_response(net_l2l3_endpoint *ep, const arp_hdr_t *arp); +bool arp_resolve(uint32_t ip, uint8_t mac_out[6], uint32_t timeout_ms); + +#define ARP_TABLE_MAX 64 + +typedef struct arp_entry { + uint32_t ip; + uint8_t mac[6]; + uint32_t ttl_ms; + uint8_t static_entry;//1 static, 0 dynamic +} arp_entry_t; + +void arp_table_init(); + +void arp_table_put(uint32_t ip, const uint8_t mac[6], uint32_t ttl_ms, bool is_static); + +bool arp_table_get(uint32_t ip, uint8_t mac_out[6]); + +void arp_table_tick(uint32_t ms); + +void arp_table_init_static_defaults(); + +void arp_send_request(uint32_t target_ip); + +void arp_daemon_entry(); +bool arp_can_reply(); +void arp_daemon_entry(); +void arp_set_pid(uint16_t pid); +uint16_t arp_get_pid(); +void arp_input(uintptr_t frame_ptr, uint32_t frame_len); +#ifdef __cplusplus +} +#endif diff --git a/shared/net/link_layer/eth.c b/shared/net/link_layer/eth.c new file mode 100644 index 00000000..ef6af15d --- /dev/null +++ b/shared/net/link_layer/eth.c @@ -0,0 +1,55 @@ +#include "eth.h" +#include "arp.h" +#include "std/memfunctions.h" +#include "net/internet_layer/ipv4.h" +extern int net_tx_frame(uintptr_t frame_ptr, uint32_t frame_len); + +uintptr_t create_eth_packet(uintptr_t p, + const uint8_t src_mac[6], + const uint8_t dst_mac[6], + uint16_t type) +{ + eth_hdr_t* eth =(eth_hdr_t*)p; + memcpy(eth->src_mac, src_mac, 6); + memcpy(eth->dst_mac, dst_mac, 6); + eth->ethertype = __builtin_bswap16(type); + return p + sizeof(eth_hdr_t); +} + +uint16_t eth_parse_packet_type(uintptr_t ptr) { + const eth_hdr_t* eth = (const eth_hdr_t*)ptr; + return __builtin_bswap16(eth->ethertype); +} + +const uint8_t* eth_get_source_mac(uintptr_t ptr) { + const eth_hdr_t* eth = (const eth_hdr_t*)ptr; + return eth->src_mac; +} +const uint8_t* eth_get_source(uintptr_t ptr){ + const eth_hdr_t* eth = (const eth_hdr_t*)ptr; + return eth->src_mac; +} + +bool eth_send_frame(uintptr_t frame_ptr, uint32_t frame_len){ + return net_tx_frame(frame_ptr, frame_len); +} + +void eth_input(uintptr_t frame_ptr, uint32_t frame_len) { + if (frame_len < sizeof(eth_hdr_t)) return; + + uint16_t type = eth_parse_packet_type(frame_ptr); + const uint8_t* src_mac = eth_get_source_mac(frame_ptr); + uintptr_t payload_ptr = frame_ptr + sizeof(eth_hdr_t); + uint32_t payload_len = frame_len - sizeof(eth_hdr_t); + + switch (type) { + case 0x0806: + arp_input(frame_ptr, frame_len); + break; + case 0x0800: + ip_input(payload_ptr, payload_len, src_mac); + break; + default: + break; + } +} diff --git a/shared/net/link_layer/eth.h b/shared/net/link_layer/eth.h new file mode 100644 index 00000000..c3eacb8c --- /dev/null +++ b/shared/net/link_layer/eth.h @@ -0,0 +1,29 @@ +#pragma once +#include "types.h" +#include "net/network_types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct __attribute__((packed)) eth_hdr_t { + uint8_t dst_mac[6]; + uint8_t src_mac[6]; + uint16_t ethertype; +} eth_hdr_t; + +uintptr_t create_eth_packet(uintptr_t ptr, + const uint8_t src_mac[6], + const uint8_t dst_mac[6], + uint16_t type); + +uint16_t eth_parse_packet_type(uintptr_t ptr); + +const uint8_t* eth_get_source(uintptr_t ptr); + +bool eth_send_frame(uintptr_t frame_ptr, uint32_t frame_len); +void eth_input(uintptr_t frame_ptr, uint32_t frame_len); + +#ifdef __cplusplus +} +#endif diff --git a/shared/net/net.h b/shared/net/net.h new file mode 100644 index 00000000..fe521f39 --- /dev/null +++ b/shared/net/net.h @@ -0,0 +1,4 @@ +#pragma once + +#define FORMAT_IP(ipv4) (((ipv4 >> 24) & 0xFF), ((ipv4 >> 16) & 0xFF), ((ipv4 >> 8) & 0xFF), ((ipv4) & 0xFF)) +#define IP_ENCODE(ip1,ip2,ip3,ip4) (((ipv4 << 24) & 0xFF) | ((ipv4 << 16) & 0xFF) | ((ipv4 << 8) & 0xFF) | ((ipv4) & 0xFF)) \ No newline at end of file diff --git a/shared/net/network_types.h b/shared/net/network_types.h index a3c54978..57b706f6 100644 --- a/shared/net/network_types.h +++ b/shared/net/network_types.h @@ -14,16 +14,15 @@ typedef enum NetProtocol { ICMP } NetProtocol; -uint16_t checksum16(uint16_t *data, size_t len); +typedef struct net_l2l3_endpoint { + uint8_t mac[6]; + uint32_t ip; //rn ipv4 only +} net_l2l3_endpoint; -uint16_t checksum16_pipv4(uint32_t src_ip, uint32_t dst_ip, uint8_t protocol, const uint8_t* payload, uint16_t length); - -typedef struct network_connection_ctx { - uint16_t port; +typedef struct net_l4_endpoint { uint32_t ip; - uint8_t mac[6]; -} network_connection_ctx; - + uint16_t port; +} net_l4_endpoint; #ifdef __cplusplus } #endif \ No newline at end of file diff --git a/shared/net/tcp.c b/shared/net/tcp.c deleted file mode 100644 index d244fa85..00000000 --- a/shared/net/tcp.c +++ /dev/null @@ -1,198 +0,0 @@ -#include "tcp.h" -#include "console/kio.h" -#include "net/network_types.h" -#include "network_types.h" -#include "syscalls/syscalls.h" -#include "std/memfunctions.h" -#include "eth.h" -#include "ipv4.h" -#include "math/math.h" - -void create_tcp_packet(uintptr_t p, network_connection_ctx source, network_connection_ctx destination, sizedptr payload){ - p = create_eth_packet(p, source.mac, destination.mac, 0x800); - - tcp_data *data = (tcp_data*)payload.ptr; - - size_t full_size = sizeof(tcp_hdr_t) + data->options.size + data->payload.size; - - p = create_ipv4_packet(p, full_size, 0x06, source.ip, destination.ip); - - tcp_hdr_t* tcp = (tcp_hdr_t*)p; - tcp->src_port = __builtin_bswap16(source.port); - tcp->dst_port = __builtin_bswap16(destination.port); - - if (payload.size != sizeof(tcp_data)){ - printf("[TCP Packet creation error] wrong payload size %i (expected %i)",payload.size, sizeof(tcp_data)); - } - - memcpy((void*)&tcp->sequence,(void*)payload.ptr, 12); - - p += sizeof(tcp_hdr_t); - - memcpy((void*)p,(void*)data->options.ptr, data->options.size); - - p += data->options.size; - - tcp->data_offset_reserved = ((sizeof(tcp_hdr_t) + data->options.size + 3) / 4) << 4; - - memcpy((void*)p,(void*)data->payload.ptr, data->payload.size); - - tcp->checksum = __builtin_bswap16(checksum16_pipv4(source.ip,destination.ip,0x06,(uint8_t*)tcp, full_size)); -} - -sizedptr tcp_parse_packet_payload(uintptr_t ptr){ - eth_hdr_t* eth = (eth_hdr_t*)ptr; - - ptr += sizeof(eth_hdr_t); - - if (__builtin_bswap16(eth->ethertype) == 0x800){ - ipv4_hdr_t* ip = (ipv4_hdr_t*)ptr; - uint32_t srcip = __builtin_bswap32(ip->src_ip); - ptr += sizeof(ipv4_hdr_t); - if (ip->protocol == 0x06){ - tcp_hdr_t* tcp = (tcp_hdr_t*)ptr; - return (sizedptr){ptr,__builtin_bswap16(ip->total_length) - sizeof(ipv4_hdr_t)}; - } - } - - return (sizedptr){0,0}; -} - -void tcp_send(uint16_t port, network_connection_ctx *destination, tcp_data* data){ - send_packet(TCP, port, destination, data, sizeof(tcp_data)); - if ((data->flags & ~(1 << ACK_F)) != 0 || data->payload.size > 0){ - data->sequence += __builtin_bswap32(max(1,data->payload.size)); - } - data->expected_ack = __builtin_bswap32(data->sequence); -} - -void tcp_reset(uint16_t port, network_connection_ctx *destination, tcp_data* data){ - data->flags = (1 << RST_F) | (1 << ACK_F); - tcp_send(port, destination, data); -} - -bool tcp_expect_response(sizedptr *pack){ - uint16_t timeout = 10; - while (!read_packet(pack)){ - sleep(1000); - if (timeout-- == 0){ - printf("Response timeout"); - return false; - } - } - return true; -} - -uint8_t tcp_check_response(tcp_data *data, sizedptr *out){ - - sizedptr pack; - - if (!tcp_expect_response(&pack) || !pack.ptr){ - printf("Response timeout. Retrying"); - return TCP_RETRY; - } - - sizedptr payload = tcp_parse_packet_payload(pack.ptr); - if (!payload.ptr) { - printf("Wrong payload pointer. Retrying"); - return TCP_RETRY; - } - - tcp_hdr_t *response = (tcp_hdr_t*)payload.ptr; - - uint32_t ack = __builtin_bswap32(response->ack); - uint32_t seq = __builtin_bswap32(response->sequence); - - size_t hdr_size = (response->data_offset_reserved >> 4) * 4; - size_t payload_size = payload.size - hdr_size; - data->ack = __builtin_bswap32(seq+max(1,payload_size)); - - if (ack != data->expected_ack){ - printf("Wrong ack %i vs %i. Resetting", ack, data->expected_ack); - return TCP_RESET; - } - - if (response->flags != (data->flags | (1 << ACK_F))){ - printf("Wrong flags %b vs %b. Resetting",response->flags, data->flags | (1 << ACK_F)); - return TCP_RESET; - } - - if (out){ - out->ptr = payload.ptr + hdr_size; - out->size = payload_size; - } - - return TCP_OK; -} - -bool tcp_handskake(network_connection_ctx *dest, uint16_t port, tcp_data *data, uint8_t retry){ - if (retry == 5){ - printf("Exceeded max number of retries"); - return false; - } - - data->sequence = 0; - data->ack = 0; - data->flags = (1 << SYN_F); - - tcp_send(port, dest, data); - - uint8_t resp = tcp_check_response(data, 0); - if (resp == TCP_RETRY){ - sleep(1000); - return tcp_handskake(dest, port, data, retry+1); - } else if (resp == TCP_RESET){ - tcp_reset(port, dest, data); - return false; - } - - data->flags = (1 << ACK_F); - - tcp_send(port, dest, data); - - printf("Acknowledgement of acknowledgemnt sent"); - - return true; -} - -bool tcp_close(network_connection_ctx *dest, uint16_t port, tcp_data *data, uint8_t retry, uint32_t orig_seq, uint32_t orig_ack){ - if (retry == 5){ - printf("Exceeded max number of retries"); - return false; - } - - data->sequence = orig_seq; - data->ack = orig_ack; - data->flags = (1 << FIN_F) | (1 << ACK_F); - - tcp_send(port, dest, data); - - data->flags = (1 << ACK_F); - uint8_t resp = tcp_check_response(data, 0); - if (resp == TCP_RETRY){ - sleep(1000); - return tcp_handskake(dest, port, data, retry+1); - } else if (resp == TCP_RESET){ - tcp_reset(port, dest, data); - return false; - } - - data->flags = (1 << FIN_F); - - resp = tcp_check_response(data, 0); - if (resp == TCP_RETRY){ - sleep(1000); - return tcp_handskake(dest, port, data, retry+1); - } else if (resp == TCP_RESET){ - tcp_reset(port, dest, data); - return true; - } - - data->flags = (1 << ACK_F); - - tcp_send(port, dest, data); - - printf("Connection closed"); - - return true; -} \ No newline at end of file diff --git a/shared/net/tcp.h b/shared/net/tcp.h deleted file mode 100644 index fafa7fdd..00000000 --- a/shared/net/tcp.h +++ /dev/null @@ -1,61 +0,0 @@ -#pragma once - -#ifdef __cplusplus -extern "C" { -#endif - -#include "types.h" -#include "net/network_types.h" - -#define FIN_F 0 -#define SYN_F 1 -#define RST_F 2 -#define PSH_F 3 -#define ACK_F 4 -#define URG_F 5 -#define ECE_F 6 -#define CWR_F 7 - -//TODO: more response types. Indicate why we reset or retry instead of just that -#define TCP_RESET 2 -#define TCP_RETRY 1 -#define TCP_OK 0 - -typedef struct __attribute__((packed)) tcp_hdr_t { - uint16_t src_port; - uint16_t dst_port; - uint32_t sequence; - uint32_t ack; - uint8_t data_offset_reserved;// upper offset, lower reserved - uint8_t flags; - uint16_t window; - uint16_t checksum; - uint16_t urgent_ptr; -} tcp_hdr_t; - -typedef struct tcp_data { - uint32_t sequence; - uint32_t ack; - uint8_t padding; - uint8_t flags; - uint16_t window; - sizedptr options; - sizedptr payload; - uint32_t expected_ack; -} tcp_data; - -void create_tcp_packet(uintptr_t p, network_connection_ctx source, network_connection_ctx destination, sizedptr payload); -size_t calc_tcp_size(uint16_t payload_len); -uint16_t tcp_parse_packet(uintptr_t ptr); -sizedptr tcp_parse_packet_payload(uintptr_t ptr); - -void tcp_send(uint16_t port, network_connection_ctx *destination, tcp_data* data); -void tcp_reset(uint16_t port, network_connection_ctx *destination, tcp_data* data); -bool expect_response(sizedptr *pack); -uint8_t tcp_check_response(tcp_data *data, sizedptr *out); -bool tcp_handskake(network_connection_ctx *dest, uint16_t port, tcp_data *data, uint8_t retry); -bool tcp_close(network_connection_ctx *dest, uint16_t port, tcp_data *data, uint8_t retry, uint32_t orig_seq, uint32_t orig_ack); - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/shared/net/transport_layer/csocket_tcp.cpp b/shared/net/transport_layer/csocket_tcp.cpp new file mode 100644 index 00000000..af408fb6 --- /dev/null +++ b/shared/net/transport_layer/csocket_tcp.cpp @@ -0,0 +1,74 @@ +#pragma once +#include "net/transport_layer/socket_tcp.hpp" +#include "net/transport_layer/socket.hpp" +#include "csocket_tcp.h" + +extern "C" { + +socket_handle_t socket_tcp_create(uint8_t role, uint32_t pid) { + return reinterpret_cast(new TCPSocket(role, pid)); +} + +int32_t socket_bind_tcp(socket_handle_t sh, uint16_t port) { + return reinterpret_cast(sh)->bind(port); +} + +int32_t socket_listen_tcp(socket_handle_t sh, int32_t backlog) { + return reinterpret_cast(sh)->listen(backlog); +} + +socket_handle_t socket_accept_tcp(socket_handle_t sh) { + TCPSocket* srv = reinterpret_cast(sh); + TCPSocket* client = srv->accept(); + return reinterpret_cast(client); +} + +int32_t socket_connect_tcp(socket_handle_t sh, uint32_t ip, uint16_t port) { + return reinterpret_cast(sh)->connect(ip, port); +} + +int64_t socket_send_tcp(socket_handle_t sh, const void* buf, uint64_t len) { + return reinterpret_cast(sh)->send(buf, len); +} + +int64_t socket_recv_tcp(socket_handle_t sh, void* buf, uint64_t len) { + return reinterpret_cast(sh)->recv(buf, len); +} + +int32_t socket_close_tcp(socket_handle_t sh) { + return reinterpret_cast(sh)->close(); +} + +void socket_destroy_tcp(socket_handle_t sh) { + delete reinterpret_cast(sh); +} + +uint16_t socket_get_local_port_tcp(socket_handle_t sh) { + return reinterpret_cast(sh)->get_local_port(); +} + +uint32_t socket_get_remote_ip_tcp(socket_handle_t sh) { + return reinterpret_cast(sh)->get_remote_ip(); +} + +uint16_t socket_get_remote_port_tcp(socket_handle_t sh) { + return reinterpret_cast(sh)->get_remote_port(); +} + +uint8_t socket_get_protocol_tcp(socket_handle_t sh) { + return reinterpret_cast(sh)->get_protocol(); +} + +uint8_t socket_get_role_tcp(socket_handle_t sh) { + return reinterpret_cast(sh)->get_role(); +} + +bool socket_is_bound_tcp(socket_handle_t sh) { + return reinterpret_cast(sh)->is_bound(); +} + +bool socket_is_connected_tcp(socket_handle_t sh) { + return reinterpret_cast(sh)->is_connected(); +} + +} diff --git a/shared/net/transport_layer/csocket_tcp.h b/shared/net/transport_layer/csocket_tcp.h new file mode 100644 index 00000000..300ae33a --- /dev/null +++ b/shared/net/transport_layer/csocket_tcp.h @@ -0,0 +1,31 @@ +#pragma once + +#include "types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void* socket_handle_t; + +socket_handle_t socket_tcp_create(uint8_t role, uint32_t pid); +int32_t socket_bind_tcp(socket_handle_t sh, uint16_t port); +int32_t socket_listen_tcp(socket_handle_t sh, int32_t backlog); +socket_handle_t socket_accept_tcp(socket_handle_t sh); +int32_t socket_connect_tcp(socket_handle_t sh, uint32_t ip, uint16_t port); +int64_t socket_send_tcp(socket_handle_t sh, const void* buf, uint64_t len); +int64_t socket_recv_tcp(socket_handle_t sh, void* buf, uint64_t len); +int32_t socket_close_tcp(socket_handle_t sh); +void socket_destroy_tcp(socket_handle_t sh); + +uint16_t socket_get_local_port_tcp(socket_handle_t sh); +uint32_t socket_get_remote_ip_tcp(socket_handle_t sh); +uint16_t socket_get_remote_port_tcp(socket_handle_t sh); +uint8_t socket_get_protocol_tcp(socket_handle_t sh); +uint8_t socket_get_role_tcp(socket_handle_t sh); +bool socket_is_bound_tcp(socket_handle_t sh); +bool socket_is_connected_tcp(socket_handle_t sh); + +#ifdef __cplusplus +} +#endif diff --git a/shared/net/transport_layer/csocket_udp.cpp b/shared/net/transport_layer/csocket_udp.cpp new file mode 100644 index 00000000..ea2ae548 --- /dev/null +++ b/shared/net/transport_layer/csocket_udp.cpp @@ -0,0 +1,62 @@ +#pragma once +#include "net/transport_layer/socket_udp.hpp" +#include "net/transport_layer/socket.hpp" +#include "csocket_udp.h" + +extern "C" socket_handle_t udp_socket_create(uint8_t role, uint32_t pid) { + return reinterpret_cast(new UDPSocket(role, pid)); +} + +extern "C" int32_t socket_bind_udp(socket_handle_t sh, uint16_t port) { + return reinterpret_cast(sh)->bind(port); +} + +extern "C" int64_t socket_sendto_udp(socket_handle_t sh, + uint32_t ip, uint16_t port, + const void* buf, uint64_t len) { + auto sock = reinterpret_cast(sh); + return sock->sendto(ip, port, buf, len); +} + +extern "C" int64_t socket_recvfrom_udp(socket_handle_t sh, + void* buf, uint64_t len, + uint32_t* out_ip, uint16_t* out_port) { + auto sock = reinterpret_cast(sh); + return sock->recvfrom(buf, len, out_ip, out_port); +} + +extern "C" int32_t socket_close_udp(socket_handle_t sh) { + return reinterpret_cast(sh)->close(); +} + +extern "C" void socket_destroy_udp(socket_handle_t sh) { + delete reinterpret_cast(sh); +} + +extern "C" uint16_t socket_get_local_port_udp(socket_handle_t sh) { + return reinterpret_cast(sh)->get_local_port(); +} + +extern "C" uint16_t socket_get_remote_port_udp(socket_handle_t sh) { + return reinterpret_cast(sh)->get_remote_port(); +} + +extern "C" uint32_t socket_get_remote_ip_udp(socket_handle_t sh) { + return reinterpret_cast(sh)->get_remote_ip(); +} + +extern "C" uint8_t socket_get_protocol_udp(socket_handle_t sh) { + return reinterpret_cast(sh)->get_protocol(); +} + +extern "C" uint8_t socket_get_role_udp(socket_handle_t sh) { + return reinterpret_cast(sh)->get_role(); +} + +extern "C" bool socket_is_bound_udp(socket_handle_t sh) { + return reinterpret_cast(sh)->is_bound(); +} + +extern "C" bool socket_is_connected_udp(socket_handle_t sh) { + return reinterpret_cast(sh)->is_connected(); +} diff --git a/shared/net/transport_layer/csocket_udp.h b/shared/net/transport_layer/csocket_udp.h new file mode 100644 index 00000000..dd37c1ec --- /dev/null +++ b/shared/net/transport_layer/csocket_udp.h @@ -0,0 +1,41 @@ +#pragma once +#include "types.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef void* socket_handle_t; + +socket_handle_t udp_socket_create(uint8_t role, uint32_t pid); + +int32_t socket_bind_udp(socket_handle_t sh, uint16_t port); + +int64_t socket_sendto_udp(socket_handle_t sh, + uint32_t ip, uint16_t port, + const void* buf, uint64_t len); + +int64_t socket_recvfrom_udp(socket_handle_t sh, + void* buf, uint64_t len, + uint32_t* out_ip, uint16_t* out_port); + +int32_t socket_close_udp(socket_handle_t sh); + +void socket_destroy_udp(socket_handle_t sh); + +uint16_t socket_get_local_port_udp(socket_handle_t sh); + +uint16_t socket_get_remote_port_udp(socket_handle_t sh); + +uint32_t socket_get_remote_ip_udp(socket_handle_t sh); + +uint8_t socket_get_protocol_udp(socket_handle_t sh); + +uint8_t socket_get_role_udp(socket_handle_t sh); + +bool socket_is_bound_udp(socket_handle_t sh); + +bool socket_is_connected_udp(socket_handle_t sh); +#ifdef __cplusplus +} +#endif diff --git a/shared/net/transport_layer/socket.hpp b/shared/net/transport_layer/socket.hpp new file mode 100644 index 00000000..0454a409 --- /dev/null +++ b/shared/net/transport_layer/socket.hpp @@ -0,0 +1,80 @@ +#pragma once + +#include "types.h" +#include "net/network_types.h" +#include "networking/port_manager.h" +#include "tcp.h" +#include "udp.h" + +#ifdef __cplusplus +extern "C" { +#endif + +//protos +#define PROTO_TCP 1 +#define PROTO_UDP 2 + +//roles +#define SOCK_ROLE_CLIENT 0 +#define SOCK_ROLE_SERVER 1 + +#define SOCK_OK 0 +#define SOCK_ERR_INVAL -1 +#define SOCK_ERR_BOUND -2 +#define SOCK_ERR_NOT_BOUND -3 +#define SOCK_ERR_PERM -4 +#define SOCK_ERR_NO_PORT -5 +#define SOCK_ERR_SYS -6 +#define SOCK_ERR_PROTO -7 +#define SOCK_ERR_STATE -8 + +#ifdef __cplusplus +} +#endif + + +#ifdef __cplusplus + +class Socket { +protected: + uint16_t localPort = 0; + uint32_t remoteIP = 0; + uint16_t remotePort = 0; + uint8_t proto; + uint8_t role; + bool bound = false; + bool connected = false; + uint16_t pid = 0; + + Socket(uint8_t protocol, uint8_t r) + : proto(protocol), role(r) {} + +public: + virtual ~Socket() { close(); } + + virtual int32_t bind(uint16_t port) = 0; + + virtual int32_t close() { + if (bound) { + if (proto == PROTO_UDP) { + udp_unbind(localPort, pid); + } else if (proto == PROTO_TCP) { + tcp_unbind(localPort, pid); + } + bound = false; + localPort = 0; + } + connected = false; + return SOCK_OK; + } + + uint16_t get_local_port() const { return localPort; } + uint16_t get_remote_port() const { return remotePort; } + uint32_t get_remote_ip() const { return remoteIP; } + uint8_t get_protocol() const { return proto; } + uint8_t get_role() const { return role; } + bool is_bound() const { return bound; } + bool is_connected() const { return connected; } +}; + +#endif diff --git a/shared/net/transport_layer/socket_tcp.hpp b/shared/net/transport_layer/socket_tcp.hpp new file mode 100644 index 00000000..59363d5c --- /dev/null +++ b/shared/net/transport_layer/socket_tcp.hpp @@ -0,0 +1,230 @@ +#pragma once + +#include "console/kio.h" +#include "std/string.h" +#include "net/internet_layer/ipv4.h" +#include "std/memfunctions.h" +#include "socket.hpp" +#include "net/transport_layer/tcp.h" +#include "types.h" +#include "data_struct/ring_buffer.hpp" + +#define KP(fmt, ...) \ + do { kprintf(fmt, ##__VA_ARGS__); } while (0) + +extern "C" { + void sleep(uint64_t ms); + uintptr_t malloc(uint64_t size); + void free(void *ptr, uint64_t size); +} + +static constexpr int TCP_MAX_BACKLOG = 8; + +class TCPSocket : public Socket { + + inline static TCPSocket* s_by_port[MAX_PORTS] = { nullptr }; + inline static TCPSocket* s_list_head = nullptr; + + static constexpr int TCP_RING_CAP = 1024; + RingBuffer ring; + tcp_data* flow = nullptr; + + TCPSocket* pending[TCP_MAX_BACKLOG] = { nullptr }; + int backlogCap = 0; + int backlogLen = 0; + + static void dispatch(uintptr_t ptr, + uint32_t len, + uint32_t src_ip, + uint16_t src_port, + uint16_t dst_port) + { + if (len == 0) { + TCPSocket* srv = s_by_port[dst_port]; + if (!srv || srv->role != SOCK_ROLE_SERVER) return; + if (srv->backlogLen >= srv->backlogCap) return; + + TCPSocket* child = new TCPSocket(); + child->localPort = dst_port; + child->remoteIP = src_ip; + child->remotePort = src_port; + child->bound = true; + child->connected = true; + child->flow = tcp_get_ctx(dst_port, src_ip, src_port); + child->pid = srv->pid; + child->insert_in_global_list(); + srv->pending[srv->backlogLen++] = child; + return; + } + + for (TCPSocket* s = s_list_head; s; s = s->next) { + if (s->connected && + s->localPort == dst_port && + s->remoteIP == src_ip && + s->remotePort == src_port) + { + s->on_receive(ptr, len); + return; + } + } + } + + void on_receive(uintptr_t ptr, uint32_t len) { + auto data = reinterpret_cast(malloc(len)); + if (!data) return; + + memcpy(data, (void*)ptr, len); + sizedptr packet = { (uintptr_t)data, len }; + + if (!ring.push(packet)) { + sizedptr dropped; + ring.pop(dropped); + free((void*)dropped.ptr, dropped.size); + ring.push(packet); + } + } + + void insert_in_global_list() { + next = s_list_head; + s_list_head = this; + } + + void remove_from_global_list() { + TCPSocket** cur = &s_list_head; + while (*cur) { + if (*cur == this) { + *cur = (*cur)->next; + break; + } + cur = &((*cur)->next); + } + } + + TCPSocket* next =nullptr; + +public: + explicit TCPSocket(uint8_t r = SOCK_ROLE_CLIENT, uint32_t pid_ = 0) + : Socket(PROTO_TCP, r) + { + if (pid_ != 0) { + pid = pid_; + insert_in_global_list(); + } + } + + ~TCPSocket() override { + close(); + remove_from_global_list(); + } + + int32_t bind(uint16_t port) override { + if (role != SOCK_ROLE_SERVER) return SOCK_ERR_PERM; + if (bound) return SOCK_ERR_BOUND; + if (!tcp_bind(port, pid, dispatch)) + return SOCK_ERR_SYS; + s_by_port[port] = this; + localPort = port; bound = true; + return SOCK_OK; + } + + int32_t connect(uint32_t ip, uint16_t port) { + if (role != SOCK_ROLE_CLIENT) return SOCK_ERR_PERM; + if (connected) return SOCK_ERR_STATE; + + int p = tcp_alloc_ephemeral(pid, dispatch); + if (p < 0) return SOCK_ERR_NO_PORT; + s_by_port[p] = this; + localPort = p; bound = true; + + tcp_data ctx_copy{}; + net_l4_endpoint dst{ip, port}; + if (!tcp_handshake(p, &dst, &ctx_copy, 0)) return SOCK_ERR_SYS; + + flow = tcp_get_ctx(p, ip, port); + if (!flow) return SOCK_ERR_SYS; + + remoteIP = ip; + remotePort = port; + connected = true; + return SOCK_OK; + } + + int64_t send(const void* buf, uint64_t len) { + if (!connected || !flow) return SOCK_ERR_STATE; + flow->payload = { (uintptr_t)buf, (uint32_t)len }; + flow->flags = (1< TCP_MAX_BACKLOG ? TCP_MAX_BACKLOG : max_backlog; + backlogLen = 0; + return SOCK_OK; + } + + TCPSocket* accept() { + const int max_iters = 100; + int iter = 0; + while (backlogLen == 0) { + if (++iter > max_iters) return nullptr; + sleep(10); + } + TCPSocket* client = pending[0]; + for (int i = 1; i < backlogLen; ++i) + pending[i - 1] = pending[i]; + pending[--backlogLen] = nullptr; + return client; + } + + int32_t close_client() { + + if (connected && flow) { + tcp_flow_close(flow); + connected = false; + } + if (bound) { + if (s_by_port[localPort] == this) { + tcp_unbind(localPort, pid); + s_by_port[localPort] = nullptr; + } + bound = false; + } + sizedptr pkt; + while (ring.pop(pkt)) { + free((void*)pkt.ptr, pkt.size); + } + return SOCK_OK; + } + + int32_t close_server() { + if (bound) { + tcp_unbind(localPort, pid); + s_by_port[localPort] = nullptr; + bound = false; + } + for (int i = 0; i < backlogLen; ++i) { + delete pending[i]; + } + backlogLen = 0; + sizedptr pkt; + while (ring.pop(pkt)) { + free((void*)pkt.ptr, pkt.size); + } + return SOCK_OK; + } + + int32_t close() override { + return role == SOCK_ROLE_SERVER ? close_server() : close_client(); + } +}; diff --git a/shared/net/transport_layer/socket_udp.hpp b/shared/net/transport_layer/socket_udp.hpp new file mode 100644 index 00000000..3a9640c5 --- /dev/null +++ b/shared/net/transport_layer/socket_udp.hpp @@ -0,0 +1,143 @@ +#pragma once + +#include "socket.hpp" +#include "net/transport_layer/udp.h" +#include "types.h" +#include "std/string.h" +#include "net/internet_layer/ipv4.h" +#include "std/memfunctions.h" + +extern "C" { + void sleep(uint64_t ms); + uintptr_t malloc(uint64_t size); + void free(void *ptr, uint64_t size); +} + +static constexpr int32_t UDP_RING_CAP = 1024; + +class UDPSocket : public Socket { + static UDPSocket* s_by_port[MAX_PORTS]; + + sizedptr ring[UDP_RING_CAP]; + uint32_t src_ips[UDP_RING_CAP]; + uint16_t src_ports[UDP_RING_CAP]; + + int32_t r_head = 0, r_tail = 0; + + static void dispatch(uintptr_t ptr, uint32_t len, //b UDPSocket::dispatch + uint32_t src_ip, uint16_t src_port, uint16_t dst_port) { + auto *sock = s_by_port[dst_port]; + if (!sock) return; + + if (sock->remotePort != 0 && sock->remotePort != src_port)return; + + //if (sock->remoteIP != 0 && sock->remoteIP != src_ip)return; + + sock->on_receive(ptr, len, src_ip, src_port); + } + + void on_receive(uintptr_t ptr, + uint32_t len, + uint32_t src_ip, + uint16_t src_port) + { + this->remoteIP = src_ip; + this->remotePort = src_port; + + uintptr_t copy = malloc(len); + if (!copy) { + free((void*)ptr, len); + return; + } + memcpy((void*)copy, (void*)ptr, len); + free((void*)ptr, len); + + int next = (r_tail + 1) % UDP_RING_CAP; + if (next == r_head) { + free((void*)ring[r_head].ptr, ring[r_head].size); + r_head = (r_head + 1) % UDP_RING_CAP; + } + + ring[r_tail] = { (uintptr_t)copy, len }; + src_ips[r_tail] = src_ip; + src_ports[r_tail] = src_port; + r_tail = next; + } + +public: + UDPSocket(uint8_t r, uint32_t pid_) : Socket(PROTO_UDP, r) { + this->pid = pid_; + this->role = r; + this->proto = PROTO_UDP; + if (this->role == SOCK_ROLE_CLIENT) { + int p = udp_alloc_ephemeral(pid_, dispatch); + if (p >= 0) { + s_by_port[p] = this; + this->localPort = p; + this->bound = true; + } + } + } + + ~UDPSocket() override { close(); } + + int32_t bind(uint16_t port) override { + if (role != SOCK_ROLE_SERVER) return SOCK_ERR_PERM; + if (bound) return SOCK_ERR_BOUND; + if (!udp_bind(port, pid, dispatch)) + return SOCK_ERR_SYS; + s_by_port[port] = this; + this->localPort = port; + this->bound = true; + return SOCK_OK; + } + + int64_t sendto(uint32_t ip, uint16_t port, const void* buf, uint64_t len) { + if (!bound) return SOCK_ERR_NOT_BOUND; + net_l4_endpoint src{ ipv4_get_cfg()->ip, localPort }; + net_l4_endpoint dst{ ip, port }; + sizedptr pay{ (uintptr_t)buf, (uint32_t)len }; + udp_send_segment(&src, &dst, pay); + this->remoteIP = ip; + this->remotePort = port; + return (int64_t)len; + } + + int64_t recvfrom(void* buf, uint64_t len, //b UDPSocket::recvfrom + uint32_t* src_ip, + uint16_t* src_port) + { + if (r_head == r_tail)return 0; + + auto p = ring[r_head]; + uint32_t ip = src_ips[r_head]; + uint16_t pt = src_ports[r_head]; + r_head = (r_head + 1) % UDP_RING_CAP; + + uint32_t tocpy = p.size < len ? p.size : (uint32_t)len; + memcpy(buf, (void*)p.ptr, tocpy); + + if (src_ip) *src_ip = ip; + if (src_port) *src_port = pt; + + this->remoteIP = ip; + this->remotePort = pt; + + free((void*)p.ptr, p.size); + return tocpy; + } + + int32_t close() override { + while (r_head != r_tail) { + free((void*)ring[r_head].ptr, + ring[r_head].size); + r_head = (r_head + 1) % UDP_RING_CAP; + } + udp_unbind(localPort, pid); + s_by_port[localPort] = nullptr; + bound = connected = false; + return Socket::close(); +} +}; + +UDPSocket* UDPSocket::s_by_port[MAX_PORTS] = { nullptr }; diff --git a/shared/net/transport_layer/tcp.c b/shared/net/transport_layer/tcp.c new file mode 100644 index 00000000..5bd86a31 --- /dev/null +++ b/shared/net/transport_layer/tcp.c @@ -0,0 +1,619 @@ +#include "tcp.h" +#include "types.h" +#include "networking/port_manager.h" +#include "net/internet_layer/ipv4.h" +#include "std/memfunctions.h" +#include "math/rng.h" +//TODO: add mtu check and fragmentation. also fragment rebuild +extern uintptr_t malloc(uint64_t size); +extern void free(void *ptr, uint64_t size); +extern void sleep(uint64_t ms); + +static tcp_flow_t tcp_flows[MAX_TCP_FLOWS]; +static inline uint16_t htons(uint16_t x) { + return (uint16_t)((x << 8) | (x >> 8)); +} +static inline uint16_t ntohs(uint16_t x) { + return htons(x); +} + +static inline uint32_t htonl(uint32_t x) { + return ((x & 0x000000FFU) << 24) | + ((x & 0x0000FF00U) << 8) | + ((x & 0x00FF0000U) >> 8) | + ((x & 0xFF000000U) >> 24); +} +static inline uint32_t ntohl(uint32_t x) { + return htonl(x); +} +tcp_data* tcp_get_ctx(uint16_t local_port, + uint32_t remote_ip, + uint16_t remote_port) +{ + int idx = find_flow(local_port, remote_ip, remote_port); + if (idx < 0) + return NULL; + return &tcp_flows[idx].ctx; +} +static uint32_t checksum_add(uint32_t sum, uint16_t val) { + sum += val; + if (sum > 0xFFFF) { + sum = (sum & 0xFFFF) + 1; + } + return sum; +} +uint16_t tcp_compute_checksum(const void *segment, + uint16_t seg_len, + uint32_t src_ip, + uint32_t dst_ip) +{ + const uint8_t *seg = (const uint8_t *)segment; + const uint64_t total_len = 12 + seg_len; + + uintptr_t raw = malloc(total_len); + if (!raw) { + return 0; + } + uint8_t *buf = (uint8_t *)raw; + + buf[0] = (src_ip >> 24) & 0xFF; + buf[1] = (src_ip >> 16) & 0xFF; + buf[2] = (src_ip >> 8) & 0xFF; + buf[3] = (src_ip >> 0) & 0xFF; + buf[4] = (dst_ip >> 24) & 0xFF; + buf[5] = (dst_ip >> 16) & 0xFF; + buf[6] = (dst_ip >> 8) & 0xFF; + buf[7] = (dst_ip >> 0) & 0xFF; + buf[8] = 0; + buf[9] = 6; + buf[10] = (seg_len >> 8) & 0xFF; + buf[11] = (seg_len >> 0) & 0xFF; + memcpy(buf + 12, seg, seg_len); + buf[12 + 16] = 0; + buf[12 + 17] = 0; + + uint32_t sum = 0; + for (uint64_t i = 0; i + 1 < total_len; i += 2) { + uint16_t word = (uint16_t)buf[i] << 8 | buf[i + 1]; + sum = checksum_add(sum, word); + } + + if (total_len & 1) { + uint16_t word = (uint16_t)buf[total_len - 1] << 8; + sum = checksum_add(sum, word); + } + + free((void *)raw, total_len); + + return htons((uint16_t)(~sum & 0xFFFF)); +} + +static int find_flow(uint16_t local_port, uint32_t remote_ip, uint16_t remote_port) { + for (int i = 0; i < MAX_TCP_FLOWS; ++i) { + tcp_flow_t *f = &tcp_flows[i]; + if (f->state != TCP_STATE_CLOSED) { + if (f->local_port == local_port) { + if (f->state == TCP_LISTEN) { + if (remote_ip == 0 && remote_port == 0) { + return i; + } + } + if (f->remote.ip == remote_ip && f->remote.port == remote_port) { + return i; + } + } + } + } + return -1; +} + +static int allocate_flow_entry() { + for (int i = 0; i < MAX_TCP_FLOWS; ++i) { + if (tcp_flows[i].state == TCP_STATE_CLOSED) { + tcp_flows[i].state = TCP_STATE_CLOSED; + tcp_flows[i].retries = 0; + return i; + } + } + return -1; +} + +static void free_flow_entry(int idx) { + if (idx >= 0 && idx < MAX_TCP_FLOWS) { + tcp_flows[idx].state = TCP_STATE_CLOSED; + tcp_flows[idx].local_port = 0; + tcp_flows[idx].remote.ip = 0; + tcp_flows[idx].remote.port = 0; + tcp_flows[idx].ctx.sequence = 0; + tcp_flows[idx].ctx.ack = 0; + tcp_flows[idx].ctx.flags = 0; + tcp_flows[idx].ctx.window = 0; + tcp_flows[idx].ctx.options.ptr = 0; + tcp_flows[idx].ctx.options.size = 0; + tcp_flows[idx].ctx.payload.ptr = 0; + tcp_flows[idx].ctx.payload.size = 0; + tcp_flows[idx].ctx.expected_ack = 0; + tcp_flows[idx].ctx.ack_received = 0; + tcp_flows[idx].retries = 0; + } +} + +static bool send_tcp_segment(uint32_t src_ip, uint32_t dst_ip, tcp_hdr_t *hdr, const uint8_t *payload, uint16_t payload_len) { + uint8_t header_words = sizeof(tcp_hdr_t) / 4; + hdr->data_offset_reserved = (header_words << 4) | 0x0; + hdr->window = htons(hdr->window); + uint16_t tcp_len = sizeof(tcp_hdr_t) + payload_len; + uint8_t *segment = (uint8_t*) malloc(tcp_len); + if (!segment) { + return false; + } + memcpy(segment, hdr, sizeof(tcp_hdr_t)); + if (payload_len > 0) { + memcpy(segment + sizeof(tcp_hdr_t), payload, payload_len); + } + tcp_hdr_t *hdr_on_buf = (tcp_hdr_t*) segment; + hdr_on_buf->checksum = 0; + uint16_t csum = tcp_compute_checksum(segment, tcp_len, src_ip, dst_ip); + hdr_on_buf->checksum = csum; + ipv4_send_segment(src_ip, dst_ip, 6, (sizedptr){ .ptr = (uintptr_t)segment, .size = tcp_len }); + free(segment, tcp_len); + return true; +} + +static void send_reset(uint32_t src_ip, uint32_t dst_ip, + uint16_t src_port, uint16_t dst_port, + uint32_t seq, uint32_t ack, bool ack_valid) { + tcp_hdr_t rst_hdr; + rst_hdr.src_port = htons(src_port); + rst_hdr.dst_port = htons(dst_port); + if (ack_valid) { + rst_hdr.sequence = htonl(0); + rst_hdr.ack = htonl(seq + 1); + rst_hdr.flags = (1 << RST_F) | (1 << ACK_F); + } else { + rst_hdr.sequence = htonl(ack); + rst_hdr.ack = htonl(0); + rst_hdr.flags = (1 << RST_F); + } + rst_hdr.window = 0; + rst_hdr.urgent_ptr = 0; + + send_tcp_segment(src_ip, dst_ip, &rst_hdr, NULL, 0); +} + +bool tcp_bind(uint16_t port, uint16_t pid, port_recv_handler_t handler) { + if (!port_bind_manual(PROTO_TCP, port, pid, handler)) { + return false; + } + + int idx = allocate_flow_entry(); + if (idx >= 0) { + tcp_flows[idx].local_port = port; + tcp_flows[idx].remote.ip = 0; + tcp_flows[idx].remote.port = 0; + tcp_flows[idx].state = TCP_LISTEN; + tcp_flows[idx].ctx.sequence = 0; + tcp_flows[idx].ctx.ack = 0; + tcp_flows[idx].ctx.flags = 0; + tcp_flows[idx].ctx.window = 0xFFFF; + tcp_flows[idx].ctx.options.ptr = 0; + tcp_flows[idx].ctx.options.size = 0; + tcp_flows[idx].ctx.payload.ptr = 0; + tcp_flows[idx].ctx.payload.size = 0; + tcp_flows[idx].ctx.expected_ack = 0; + tcp_flows[idx].ctx.ack_received = 0; + tcp_flows[idx].retries = 0; + } + return true; +} + +int tcp_alloc_ephemeral(uint16_t pid, port_recv_handler_t handler) { + int port = port_alloc_ephemeral(PROTO_TCP, pid, handler); + return port; +} + +bool tcp_unbind(uint16_t port, uint16_t pid) { + bool res = port_unbind(PROTO_TCP, port, pid); + if (res) { + for (int i = 0; i < MAX_TCP_FLOWS; ++i) { + if (tcp_flows[i].local_port == port) { + free_flow_entry(i); + } + } + } + return res; +} + +bool tcp_handshake(uint16_t local_port, net_l4_endpoint *dst, tcp_data *flow_ctx, uint16_t pid) { + int idx = allocate_flow_entry(); + if (idx < 0) { + return false; + } + tcp_flow_t *flow = &tcp_flows[idx]; + flow->local_port = local_port; + flow->remote.ip = dst->ip; + flow->remote.port = dst->port; + flow->state = TCP_SYN_SENT; + flow->retries = TCP_SYN_RETRIES; + uint32_t iss = 1; + flow->ctx.sequence = iss; + flow->ctx.ack = 0; + flow->ctx.window = 0xFFFF; + flow->ctx.options.ptr = 0; + flow->ctx.options.size = 0; + flow->ctx.payload.ptr = 0; + flow->ctx.payload.size = 0; + flow->ctx.flags = (1 << SYN_F); + flow->ctx.expected_ack = iss + 1; + flow->ctx.ack_received = 0; + tcp_hdr_t syn_hdr; + syn_hdr.src_port = htons(local_port); + syn_hdr.dst_port = htons(dst->port); + syn_hdr.sequence = htonl(flow->ctx.sequence); + syn_hdr.ack = htonl(0); + syn_hdr.flags = (1 << SYN_F); + syn_hdr.window = flow->ctx.window; + syn_hdr.urgent_ptr = 0; + uint32_t src_ip = ipv4_get_cfg()->ip; + bool sent = false; + while (flow->retries-- > 0) { + sent = send_tcp_segment(src_ip, dst->ip, &syn_hdr, NULL, 0); + if (!sent) { + break; + } + uint64_t wait_ms = TCP_RETRY_TIMEOUT_MS; + uint64_t elapsed = 0; + const uint64_t interval = 50; + while (elapsed < wait_ms) { + if (flow->state == TCP_ESTABLISHED) { + *flow_ctx = flow->ctx; + return true; + } + if (flow->state == TCP_STATE_CLOSED) { + free_flow_entry(idx); + return false; + } + sleep(interval); + elapsed += interval; + } + } + free_flow_entry(idx); + return false; +} + +tcp_result_t tcp_flow_send(tcp_data *flow_ctx) { + if (!flow_ctx) { + return TCP_INVALID; + } + tcp_flow_t *flow = NULL; + for (int i = 0; i < MAX_TCP_FLOWS; ++i) { + if (&tcp_flows[i].ctx == flow_ctx) { + flow = &tcp_flows[i]; + break; + } + } + if (!flow) { + return TCP_INVALID; + } + + uint8_t flags = flow_ctx->flags; + uint8_t *payload_ptr = (uint8_t*) flow_ctx->payload.ptr; + uint16_t payload_len = flow_ctx->payload.size; + if (flow->state != TCP_ESTABLISHED && !(flags & (1<state == TCP_CLOSE_WAIT && (flags & (1<local_port); + hdr.dst_port = htons(flow->remote.port); + hdr.sequence = htonl(flow_ctx->sequence); + hdr.ack = htonl(flow_ctx->ack); + hdr.flags = flags; + hdr.window = flow_ctx->window ? flow_ctx->window : 0xFFFF; + hdr.urgent_ptr = 0; + + uint32_t src_ip = ipv4_get_cfg()->ip; + uint32_t dst_ip = flow->remote.ip; + + bool sent = send_tcp_segment(src_ip, dst_ip, &hdr, payload_ptr, payload_len); + if (!sent) { + return TCP_RESET; + } + + uint32_t seq_incr = payload_len; + if (flags & (1<sequence += seq_incr; + + if ((flags & (1< 0) { + flow_ctx->expected_ack = flow_ctx->sequence; + + int retries = TCP_DATA_RETRIES; + while (retries-- > 0) { + uint64_t wait_ms = TCP_RETRY_TIMEOUT_MS; + uint64_t elapsed = 0; + const uint64_t interval = 50; + while (elapsed < wait_ms) { + if (flow_ctx->ack_received >= flow_ctx->expected_ack) { + break; + } + if (flow->state == TCP_STATE_CLOSED) { + return TCP_RESET; + } + sleep(interval); + elapsed += interval; + } + if (flow_ctx->ack_received >= flow_ctx->expected_ack) { + break; + } + + if (flow->state >= TCP_CLOSING || flow->state == TCP_STATE_CLOSED) { + break; + } + + flow_ctx->sequence -= seq_incr; + send_tcp_segment(src_ip, dst_ip, &hdr, payload_ptr, payload_len); + flow_ctx->sequence += seq_incr; + } + if (flow_ctx->ack_received < flow_ctx->expected_ack) { + return TCP_TIMEOUT; + } + } + return TCP_OK; +} + +tcp_result_t tcp_flow_close(tcp_data *flow_ctx) { + if (!flow_ctx) return TCP_INVALID; + + tcp_flow_t *flow = NULL; + for (int i = 0; i < MAX_TCP_FLOWS; ++i) { + if (&tcp_flows[i].ctx == flow_ctx) { flow = &tcp_flows[i]; break; } + } + if (!flow) return TCP_INVALID; + + if (flow->state == TCP_ESTABLISHED || flow->state == TCP_CLOSE_WAIT) { + + flow_ctx->payload.ptr = 0; + flow_ctx->payload.size = 0; + flow_ctx->flags = (1u << FIN_F) | (1u << ACK_F); + + tcp_result_t res = tcp_flow_send(flow_ctx); + if (res != TCP_OK) return res; + + if (flow->state == TCP_ESTABLISHED) { + flow->state = TCP_FIN_WAIT_1; + } else { flow->state = TCP_LAST_ACK; } + + const uint64_t max_wait = 2000; + const uint64_t interval = 100; + uint64_t elapsed = 0; + while (elapsed < max_wait) { + if (flow->state == TCP_STATE_CLOSED) break; + sleep(interval); + elapsed += interval; + } + + int idx = (int)(flow - tcp_flows); + free_flow_entry(idx); + return TCP_OK; + } + + return TCP_INVALID; +} + + +void tcp_input(uintptr_t ptr, uint32_t len, uint32_t src_ip, uint32_t dst_ip) { + if (len < sizeof(tcp_hdr_t)) { + return; + } + tcp_hdr_t *hdr = (tcp_hdr_t*) ptr; + + uint16_t recv_checksum = hdr->checksum; + hdr->checksum = 0; + uint16_t calc_checksum = tcp_compute_checksum((uint8_t*)hdr, (uint16_t)len, src_ip, dst_ip); + hdr->checksum = recv_checksum; + if (recv_checksum != calc_checksum) return; + + uint16_t src_port = ntohs(hdr->src_port); + uint16_t dst_port = ntohs(hdr->dst_port); + uint32_t seq = ntohl(hdr->sequence); + uint32_t ack = ntohl(hdr->ack); + uint8_t flags = hdr->flags; + uint16_t window = ntohs(hdr->window); + int idx = find_flow(dst_port, src_ip, src_port); + tcp_flow_t *flow = (idx >= 0 ? &tcp_flows[idx] : NULL); + + if (!flow) { + int listen_idx = find_flow(dst_port, 0, 0); + if ((flags & (1<= 0) { + //TODO: use a syscall for the rng + rng_t rng; + rng_init_random(&rng); + tcp_flow_t *lf = &tcp_flows[listen_idx]; + int new_idx = allocate_flow_entry(); + if (new_idx < 0) return; + + flow = &tcp_flows[new_idx]; + flow->local_port = dst_port; + flow->remote.ip = src_ip; + flow->remote.port = src_port; + flow->state = TCP_SYN_RECEIVED; + flow->retries = TCP_SYN_RETRIES; + + uint32_t iss = rng_next32(&rng); + flow->ctx.sequence = iss; + flow->ctx.ack = seq + 1; + flow->ctx.window = 0xFFFF; + flow->ctx.flags = 0; + flow->ctx.options.ptr = 0; + flow->ctx.options.size = 0; + flow->ctx.payload.ptr = 0; + flow->ctx.payload.size = 0; + flow->ctx.expected_ack = iss + 1; + flow->ctx.ack_received = 0; + + tcp_hdr_t synack_hdr; + synack_hdr.src_port = htons(dst_port); + synack_hdr.dst_port = htons(src_port); + synack_hdr.sequence = htonl(iss); + synack_hdr.ack = htonl(seq + 1); + synack_hdr.flags = (1<ctx.window; + synack_hdr.urgent_ptr = 0; + uint32_t src_ip_local = ipv4_get_cfg()->ip; + send_tcp_segment(src_ip_local, src_ip, &synack_hdr, NULL, 0); + + return; + } else { + if (!(flags & (1<state) { + case TCP_SYN_SENT: + if ((flags & (1<ctx.expected_ack) { + + flow->ctx.ack = seq + 1; + flow->ctx.ack_received = ack; + flow->ctx.sequence += 1; + + tcp_hdr_t final_ack; + final_ack.src_port = htons(flow->local_port); + final_ack.dst_port = htons(flow->remote.port); + final_ack.sequence = htonl(flow->ctx.sequence + 1); + final_ack.ack = htonl(flow->ctx.ack); + final_ack.flags = (1<ctx.window; + final_ack.urgent_ptr = 0; + uint32_t src_ip_local = ipv4_get_cfg()->ip; + send_tcp_segment(src_ip_local, flow->remote.ip, &final_ack, NULL, 0); + + flow->state = TCP_ESTABLISHED; + } + } else if (flags & (1<state = TCP_STATE_CLOSED; + } + return; + case TCP_SYN_RECEIVED: + if ((flags & (1<ctx.expected_ack) { + flow->ctx.sequence += 1; + flow->state = TCP_ESTABLISHED; + flow->ctx.ack_received = ack; + port_recv_handler_t h = port_get_handler(PROTO_TCP, dst_port); + if (h) { + h(0, 0, src_ip, src_port, dst_port); + } + } + } else if (flags & (1<data_offset_reserved >> 4) * 4; + uint32_t data_len = len - hdr_len; + + if ((flags & ACK_F) + && !(flags & (SYN_F|FIN_F|RST_F)) + && data_len == 0 + && seq + 1 == flow->ctx.ack) + { + tcp_hdr_t ack_hdr = { + .src_port = htons(flow->local_port), + .dst_port = htons(flow->remote.port), + .sequence = htonl(flow->ctx.sequence), + .ack = htonl(flow->ctx.ack), + .flags = (1 << ACK_F), + .window = flow->ctx.window, + .urgent_ptr = 0 + }; + send_tcp_segment( + ipv4_get_cfg()->ip, + flow->remote.ip, + &ack_hdr, NULL, 0 + ); + return; + } + case TCP_FIN_WAIT_1: + case TCP_FIN_WAIT_2: + case TCP_CLOSE_WAIT: + case TCP_CLOSING: + case TCP_LAST_ACK: { + if (flags & (1 << RST_F)) { + free_flow_entry(idx); + return; + } + uint8_t hdr_len =(hdr->data_offset_reserved >> 4) * 4; + if (len < hdr_len) + return; + uint32_t data_len = len - hdr_len; + bool fin_set = (flags & (1 << FIN_F)) != 0; + bool fin_inseq = fin_set && (seq == flow->ctx.ack); + + if (data_len && seq == flow->ctx.ack) { + flow->ctx.ack += data_len; + port_recv_handler_t h = port_get_handler(PROTO_TCP, dst_port); + if (h) h(ptr + hdr_len, data_len, src_ip, src_port, dst_port); + } + + if (fin_inseq) { + flow->ctx.ack += 1; + } + + if ((data_len && seq == flow->ctx.ack - data_len) || fin_inseq) { + tcp_hdr_t ackhdr = { + .src_port = htons(flow->local_port), + .dst_port = htons(flow->remote.port), + .sequence = htonl(flow->ctx.sequence), + .ack = htonl(flow->ctx.ack), + .flags = (1 << ACK_F), + .window = flow->ctx.window, + .urgent_ptr = 0 + }; + send_tcp_segment(ipv4_get_cfg()->ip, + flow->remote.ip, &ackhdr, NULL, 0); + } + + if ((flags & (1 << ACK_F)) && + ack > flow->ctx.ack_received) { + flow->ctx.ack_received = ack; + + if (flow->state == TCP_FIN_WAIT_1 && + ack == flow->ctx.expected_ack) { + flow->state = TCP_FIN_WAIT_2; + } else if ((flow->state == TCP_LAST_ACK || + flow->state == TCP_CLOSING) && + ack == flow->ctx.expected_ack) { + free_flow_entry(idx); + return; + } + } + + if (fin_inseq) { + if (flow->state == TCP_ESTABLISHED) + flow->state = TCP_CLOSE_WAIT; + else if (flow->state == TCP_FIN_WAIT_1) + flow->state = TCP_CLOSING; + else if (flow->state == TCP_FIN_WAIT_2) + flow->state = TCP_TIME_WAIT; + else if (flow->state == TCP_CLOSING) + flow->state = TCP_TIME_WAIT; + else if (flow->state == TCP_LAST_ACK) + flow->state = TCP_TIME_WAIT; + } + return; +} + default: + break; + } +} diff --git a/shared/net/transport_layer/tcp.h b/shared/net/transport_layer/tcp.h new file mode 100644 index 00000000..4a6f20c2 --- /dev/null +++ b/shared/net/transport_layer/tcp.h @@ -0,0 +1,116 @@ +#pragma once + +#include "networking/port_manager.h" +#include "net/internet_layer/ipv4.h" +#include "net/link_layer/eth.h" +#include "std/memfunctions.h" +#include "net/network_types.h" +#ifdef __cplusplus +extern "C" { +#endif + +#define FIN_F 0 +#define SYN_F 1 +#define RST_F 2 +#define PSH_F 3 +#define ACK_F 4 +#define URG_F 5 +#define ECE_F 6 +#define CWR_F 7 + +typedef enum { + TCP_OK = 0, + TCP_RETRY = 1, + TCP_RESET = 2, + TCP_TIMEOUT = -2, + TCP_CSUM_ERR = -3, + TCP_INVALID = -4, + TCP_WOULDBLOCK = -5, + TCP_DISCONNECT = -6, + TCP_UNIMPLEMENT = -10, + TCP_BUSY = -11, +} tcp_result_t; + +typedef struct __attribute__((packed)) { + uint16_t src_port; + uint16_t dst_port; + uint32_t sequence; + uint32_t ack; + uint8_t data_offset_reserved; + uint8_t flags; + uint16_t window; + uint16_t checksum; + uint16_t urgent_ptr; +} tcp_hdr_t; + +typedef struct { + uint32_t sequence; + uint32_t ack; + uint8_t flags; + uint16_t window; + sizedptr options; + sizedptr payload; + uint32_t expected_ack; + uint32_t ack_received; +} tcp_data; + +typedef enum { + TCP_STATE_CLOSED = 0, + TCP_LISTEN, + TCP_SYN_SENT, + TCP_SYN_RECEIVED, + TCP_ESTABLISHED, + TCP_FIN_WAIT_1, + TCP_FIN_WAIT_2, + TCP_CLOSE_WAIT, + TCP_CLOSING, + TCP_LAST_ACK, + TCP_TIME_WAIT +} tcp_state_t; + +typedef struct { + uint16_t local_port; + net_l4_endpoint remote; + tcp_state_t state; + tcp_data ctx; + uint8_t retries; +} tcp_flow_t; + +#define MAX_TCP_FLOWS 512 +#define TCP_SYN_RETRIES 5 +#define TCP_DATA_RETRIES 5 +#define TCP_RETRY_TIMEOUT_MS 1000 + +static int find_flow(uint16_t local_port, uint32_t remote_ip, uint16_t remote_port); + +tcp_data* tcp_get_ctx(uint16_t local_port, + uint32_t remote_ip, + uint16_t remote_port); + +bool tcp_bind(uint16_t port, + uint16_t pid, + port_recv_handler_t handler); + +int tcp_alloc_ephemeral(uint16_t pid, + port_recv_handler_t handler); + +bool tcp_unbind(uint16_t port, + uint16_t pid); + +bool tcp_handshake(uint16_t local_port, + net_l4_endpoint *dst, + tcp_data *flow_ctx, + uint16_t pid); + +tcp_result_t tcp_flow_send(tcp_data *flow_ctx); + +tcp_result_t tcp_flow_close(tcp_data *flow_ctx); + +void tcp_input(uintptr_t ptr, + uint32_t len, + uint32_t src_ip, + uint32_t dst_ip); + +#ifdef __cplusplus +} +#endif diff --git a/shared/net/transport_layer/udp.c b/shared/net/transport_layer/udp.c new file mode 100644 index 00000000..7a478b6e --- /dev/null +++ b/shared/net/transport_layer/udp.c @@ -0,0 +1,114 @@ +#include "udp.h" +#include "net/checksums.h" +#include "net/internet_layer/ipv4.h" +#include "networking/port_manager.h" +#include "std/memfunctions.h" +#include "types.h" + +extern void sleep(uint64_t ms); +extern uintptr_t malloc(uint64_t size); +extern void free(void *ptr, uint64_t size); + +static inline uint16_t bswap16(uint16_t v) { return __builtin_bswap16(v); } + +bool udp_bind(uint16_t port, + uint16_t pid, + port_recv_handler_t handler) +{ + return port_bind_manual(PROTO_UDP, port, pid, handler); +} + +int udp_alloc_ephemeral(uint16_t pid, + port_recv_handler_t handler) +{ + return port_alloc_ephemeral(PROTO_UDP, pid, handler); +} + +bool udp_unbind(uint16_t port, + uint16_t pid) +{ + return port_unbind(PROTO_UDP, port, pid); +} + + +size_t create_udp_segment(uintptr_t buf, + const net_l4_endpoint *src, + const net_l4_endpoint *dst, + sizedptr payload) +{ + udp_hdr_t *udp = (udp_hdr_t*)buf; + udp->src_port = bswap16(src->port); + udp->dst_port = bswap16(dst->port); + uint16_t full_len = sizeof(*udp) + payload.size; + udp->length= bswap16(full_len); + udp->checksum = 0; + + memcpy((void*)(buf + sizeof(*udp)), (void*)payload.ptr, payload.size); + + uint16_t csum = checksum16_pipv4( src->ip, dst->ip, 0x11, (const uint8_t*)udp, full_len); + udp->checksum = bswap16(csum); + return full_len; +} + +void udp_send_segment(const net_l4_endpoint *src, + const net_l4_endpoint *dst, + sizedptr payload) +{ + uint32_t udp_max = sizeof(udp_hdr_t) + payload.size; + uint32_t ip_max = sizeof(ipv4_hdr_t) + udp_max; + uint32_t eth_total = sizeof(eth_hdr_t) + ip_max; + + uintptr_t buf = (uintptr_t)malloc(eth_total); + if (!buf) return; + + uintptr_t udp_buf = buf + sizeof(eth_hdr_t) + sizeof(ipv4_hdr_t); + size_t udp_len = create_udp_segment(udp_buf, src, dst, payload); + + ipv4_send_segment(src->ip, dst->ip, 0x11,(sizedptr){ udp_buf, (uint32_t)udp_len }); + + free((void*)buf, eth_total); +} + +sizedptr udp_strip_header(uintptr_t ptr, uint32_t len) { + if (len < sizeof(udp_hdr_t)) { + return (sizedptr){0,0}; + } + udp_hdr_t *hdr = (udp_hdr_t*)ptr; + uint16_t total = bswap16(hdr->length); + if (total < sizeof(udp_hdr_t) || total > len) { + return (sizedptr){0,0}; + } + return (sizedptr){ + .ptr = ptr + sizeof(udp_hdr_t), + .size = total - sizeof(udp_hdr_t) + }; +} + +void udp_input(uintptr_t ptr, + uint32_t len, + uint32_t src_ip, + uint32_t dst_ip) +{ + sizedptr pl = udp_strip_header(ptr, len); + if (!pl.ptr) return; + + udp_hdr_t *hdr = (udp_hdr_t*)ptr; + if (hdr->checksum) { + uint16_t recv = hdr->checksum; + hdr->checksum = 0; + uint16_t calc = checksum16_pipv4( + src_ip, dst_ip, 0x11, + (const uint8_t*)hdr, + pl.size + sizeof(*hdr) + ); + hdr->checksum = recv; + if (calc != bswap16(recv)) return; + } + + uint16_t dst_port = bswap16(hdr->dst_port); + uint16_t src_port = bswap16(hdr->src_port); + port_recv_handler_t handler = port_get_handler(PROTO_UDP, dst_port); + if (handler) { + handler(pl.ptr, pl.size, src_ip, src_port, dst_port); + } +} \ No newline at end of file diff --git a/shared/net/transport_layer/udp.h b/shared/net/transport_layer/udp.h new file mode 100644 index 00000000..eb5da1d7 --- /dev/null +++ b/shared/net/transport_layer/udp.h @@ -0,0 +1,43 @@ +#pragma once +#include "types.h" +#include "net/network_types.h" +#include "networking/port_manager.h" + +#ifdef __cplusplus +extern "C" { +#endif + +typedef struct __attribute__((packed)) { + uint16_t src_port; + uint16_t dst_port; + uint16_t length; + uint16_t checksum; +} udp_hdr_t; + +size_t create_udp_segment(uintptr_t buf, + const net_l4_endpoint *src, + const net_l4_endpoint *dst, + sizedptr payload); + + +void udp_send_segment(const net_l4_endpoint *src, + const net_l4_endpoint *dst, + sizedptr payload); + +void udp_input(uintptr_t ptr, + uint32_t len, + uint32_t src_ip, + uint32_t dst_ip); + +bool udp_bind(uint16_t port, + uint16_t pid, + port_recv_handler_t handler); + +int udp_alloc_ephemeral(uint16_t pid, port_recv_handler_t handler); + +bool udp_unbind(uint16_t port, + uint16_t pid); + +#ifdef __cplusplus +} +#endif diff --git a/shared/net/udp.c b/shared/net/udp.c deleted file mode 100644 index 9ab81d8e..00000000 --- a/shared/net/udp.c +++ /dev/null @@ -1,53 +0,0 @@ -#include "udp.h" -#include "console/kio.h" -#include "net/network_types.h" -#include "syscalls/syscalls.h" -#include "eth.h" -#include "ipv4.h" - -void create_udp_packet(uintptr_t p, network_connection_ctx source, network_connection_ctx destination, sizedptr payload){ - p = create_eth_packet(p, source.mac, destination.mac, 0x800); - - p = create_ipv4_packet(p, sizeof(udp_hdr_t) + payload.size, 0x11, source.ip, destination.ip); - - udp_hdr_t* udp = (udp_hdr_t*)p; - udp->src_port = __builtin_bswap16(source.port); - udp->dst_port = __builtin_bswap16(destination.port); - udp->length = __builtin_bswap16(sizeof(udp_hdr_t) + payload.size); - - p += sizeof(udp_hdr_t); - - uint8_t* data = (uint8_t*)p; - uint8_t* payload_c = (uint8_t*)payload.ptr; - for (size_t i = 0; i < payload.size; i++) data[i] = payload_c[i]; - - udp->checksum = __builtin_bswap16(checksum16_pipv4(source.ip,destination.ip,0x11,(uint8_t*)udp,sizeof(udp_hdr_t) + payload.size)); - -} - -uint16_t udp_parse_packet(uintptr_t ptr){ - udp_hdr_t* udp = (udp_hdr_t*)ptr; - ptr += sizeof(udp_hdr_t); - uint16_t port = __builtin_bswap16(udp->dst_port); - return port; -} - -sizedptr udp_parse_packet_payload(uintptr_t ptr){ - eth_hdr_t* eth = (eth_hdr_t*)ptr; - - ptr += sizeof(eth_hdr_t); - - if (__builtin_bswap16(eth->ethertype) == 0x800){ - ipv4_hdr_t* ip = (ipv4_hdr_t*)ptr; - uint32_t srcip = __builtin_bswap32(ip->src_ip); - ptr += sizeof(ipv4_hdr_t); - if (ip->protocol == 0x11){ - udp_hdr_t* udp = (udp_hdr_t*)ptr; - ptr += sizeof(udp_hdr_t); - uint16_t payload_len = __builtin_bswap16(udp->length) - sizeof(udp_hdr_t); - return (sizedptr){ptr,payload_len}; - } - } - - return (sizedptr){0,0}; -} \ No newline at end of file diff --git a/shared/net/udp.h b/shared/net/udp.h deleted file mode 100644 index c14ab36a..00000000 --- a/shared/net/udp.h +++ /dev/null @@ -1,23 +0,0 @@ -#pragma once - -#ifdef __cplusplus -extern "C" { -#endif - -#include "types.h" -#include "net/network_types.h" - -typedef struct __attribute__((packed)) udp_hdr_t { - uint16_t src_port; - uint16_t dst_port; - uint16_t length; - uint16_t checksum; -} udp_hdr_t; - -void create_udp_packet(uintptr_t p, network_connection_ctx source, network_connection_ctx destination, sizedptr payload); -uint16_t udp_parse_packet(uintptr_t ptr); -sizedptr udp_parse_packet_payload(uintptr_t ptr); - -#ifdef __cplusplus -} -#endif \ No newline at end of file diff --git a/shared/std/string.c b/shared/std/string.c index 54f3e682..f7f572f2 100644 --- a/shared/std/string.c +++ b/shared/std/string.c @@ -329,3 +329,52 @@ uint64_t parse_hex_u64(char* str, size_t size){ } return result; } + + +string string_from_const(const char *lit) +{ + uint32_t len = strlen(lit, 0); + return (string){ (char *)lit, len, len }; +} + +string string_concat(string a, string b) +{ + uint32_t len = a.length + b.length; + char *dst = (char *)malloc(len); + if (!dst) return (string){0}; + memcpy(dst, a.data, a.length); + memcpy(dst + a.length, b.data, b.length); + return (string){ dst, len, len }; +} + +void string_concat_inplace(string *dest, string src) //b string_concat_inplace +{ + if (!dest || !src.data) return; + + uint32_t new_len = dest->length + src.length; + uint32_t new_cap = new_len + 1; + + uintptr_t raw = malloc(new_cap); + if (!raw) return; + char *dst = (char *)raw; + + if (dest->data && dest->length) { + memcpy(dst, dest->data, dest->length); + } + + memcpy(dst + dest->length, src.data, src.length); + dst[new_len] = '\0'; + if (dest->data) { + free(dest->data, dest->mem_length); + } + dest->data = dst; + dest->length = new_len; + dest->mem_length = new_cap; +} + +void string_append_bytes(string *dest, const void *buf, uint32_t len) +{ + if (!len) return; + string tmp = { (char *)buf, len, len }; + string_concat_inplace(dest, tmp); +} \ No newline at end of file diff --git a/shared/std/string.h b/shared/std/string.h index 1a7d3641..8da8b39e 100644 --- a/shared/std/string.h +++ b/shared/std/string.h @@ -42,6 +42,11 @@ uint64_t parse_hex_u64(char* str, size_t size); bool utf16tochar( uint16_t* str_in, char* out_str, size_t max_len); +string string_from_const(const char *literal); +string string_concat(string a, string b); +void string_concat_inplace(string *dest, string src); +void string_append_bytes(string *dest, const void *buf, uint32_t len); + #ifdef __cplusplus } #endif \ No newline at end of file diff --git a/shared/syscalls/syscalls.h b/shared/syscalls/syscalls.h index 85cdcbfd..673a8ea5 100644 --- a/shared/syscalls/syscalls.h +++ b/shared/syscalls/syscalls.h @@ -34,10 +34,12 @@ extern void draw_primitive_string(string *text, gpu_point *p, uint32_t scale, ui extern uint64_t get_time(); -extern bool bind_port(uint16_t port); -extern bool unbind_port(uint16_t port); -extern void send_packet(NetProtocol protocol, uint16_t port, network_connection_ctx *destination, void* payload, uint16_t payload_len); -extern bool read_packet(sizedptr *ptr); +extern bool network_bind_port_current(uint16_t port); +extern bool network_unbind_port_current(uint16_t port); +extern int network_alloc_ephemeral_port_current(); +extern int net_tx_frame(uintptr_t frame_ptr, uint32_t frame_len); +extern int net_rx_frame(sizedptr *out_frame); +extern bool dispatch_enqueue_frame(const sizedptr *frame); void printf(const char *fmt, ...);