diff --git a/.github/workflows/test-darwin-arm64.yaml b/.github/workflows/test-darwin-arm64.yaml deleted file mode 100644 index 9bd620f..0000000 --- a/.github/workflows/test-darwin-arm64.yaml +++ /dev/null @@ -1,36 +0,0 @@ -name: Test - -on: - push: - branches: - - main - pull_request: - -jobs: - test: - name: Test Darwin (go ${{ matrix.go }}, ${{ matrix.host }}) - runs-on: ${{ matrix.host }} - timeout-minutes: 10 - strategy: - matrix: - go: ['1.25', '1.26'] - host: ['macos-latest'] - - steps: - - name: Checkout code - uses: actions/checkout@v4 - - - name: Setup Go - uses: actions/setup-go@v5 - with: - go-version: ${{ matrix.go }} - cache: true - - # Function cloning is all that works on darwin/arm64 right now. - - name: Run tests (buildmode=exe) - run: | - go test -buildmode=exe -run 'TestCloneFunc' - - - name: Run tests (buildmode=pie) - run: | - go test -buildmode=pie -run 'TestCloneFunc' diff --git a/.github/workflows/test.yaml b/.github/workflows/test.yaml index 67eb70e..ee01c8b 100644 --- a/.github/workflows/test.yaml +++ b/.github/workflows/test.yaml @@ -14,7 +14,12 @@ jobs: strategy: matrix: go: ['1.25', '1.26'] - host: [ubuntu-latest, macos-15-intel, windows-latest, ubuntu-24.04-arm] + host: + - ubuntu-latest + - macos-15-intel + - windows-latest + - ubuntu-24.04-arm + - macos-latest steps: - name: Checkout code diff --git a/README.md b/README.md index 94c30ac..85bad09 100644 --- a/README.md +++ b/README.md @@ -2,7 +2,7 @@ [![Go Reference](https://pkg.go.dev/badge/github.com/pboyd/redefine.svg)](https://pkg.go.dev/github.com/pboyd/redefine) -Highly experimental package to redefine Go functions at runtime as some interpreted languages allow (Ruby, Perl, etc.). I wrote about how this works and some of the limitations [here](https://pboyd.io/posts/redefining-go-functions/). This is a fun experiment, but do not use it for production code. +Highly experimental package to redefine Go functions at runtime as some interpreted languages allow (Ruby, Perl, etc.). I wrote about how this works and some of the limitations [here](https://pboyd.io/posts/redefining-go-functions/), and about Darwin / Mac OS support in particular [here](https://pboyd.io/posts/redefining-go-functions-on-darwin-arm64/). This is a fun experiment, but do not use it for production code. ```go package main @@ -38,10 +38,10 @@ It's 5:00 PM somewhere | Darwin (macOS) | amd64 | Full | | | Linux | arm64 | Full | | | Windows | arm64 | Full | | +| Darwin (macOS) | arm64 | Full | | | FreeBSD | amd64 | Untested | Compiles but untested | | OpenBSD | amd64 | Untested | Compiles but untested | | NetBSD | amd64 | Untested | Compiles but untested | -| Darwin (macOS) | arm64 | Broken | `mprotect` returns EACCES | ## FAQ diff --git a/clone.go b/clone.go index b857b39..2b07672 100644 --- a/clone.go +++ b/clone.go @@ -11,6 +11,8 @@ import ( "unsafe" "github.com/pboyd/malloc" + "github.com/pboyd/redefine/internal/cacheflush" + "github.com/pboyd/redefine/internal/static" ) var errAddressOutOfRange = errors.New("address out of range") @@ -23,7 +25,7 @@ func cloneFunc[T any](fn T) (*clonedFunc[T], error) { return nil, fmt.Errorf("not a function, kind: %v", fnv.Kind()) } - originalCode, err := funcSlice(fn) + originalCode, err := static.GetInfo().FuncSlice(fn) if err != nil { return nil, err } @@ -64,7 +66,7 @@ func cloneFunc[T any](fn T) (*clonedFunc[T], error) { return nil, errors.New("failed to allocate memory for cloned function") } - cacheflush(newCode) + cacheflush.Flush(newCode) // This seems too complicated. The idea is to take our newly allocated // buffer of machine instructions and convince Go that it's really a @@ -128,16 +130,9 @@ func (a *allocator) init(startSize int) error { const absMinAddress = 0x100000 func initMallocBackend() (malloc.ArenaBackend, error) { - var text, etext uintptr - var end uintptr - pc, _, _, _ := runtime.Caller(0) - datap := findfunc(pc).datap - if datap != nil { - text = datap.text - etext = datap.etext - end = datap.end - } - if text == 0 || etext == 0 || end == 0 { + info := static.GetInfo() + text, etext := info.Text() + if text == 0 || etext == 0 || info.End == 0 { return nil, fmt.Errorf("failed to find moduledata") } @@ -149,7 +144,7 @@ func initMallocBackend() (malloc.ArenaBackend, error) { // // Use the size of the existing text segment so there's enough space to // clone every statically-linked function. - size := (etext - text + pageSize - 1) &^ (pageSize - 1) + size := etext - text // Cloned functions need to be near the existing text and data // segments so that they can be reached by the same @@ -163,8 +158,8 @@ func initMallocBackend() (malloc.ArenaBackend, error) { // If there's an ideal range for the architecture, try that first. if idealCloneDistance > 0 { // Search before text - minAddress := end - idealCloneDistance - if minAddress > end || minAddress < absMinAddress { + minAddress := info.End - idealCloneDistance + if minAddress > info.End || minAddress < absMinAddress { minAddress = absMinAddress } be := tryBackendRange(size, minAddress, text-pageSize-size) @@ -177,15 +172,15 @@ func initMallocBackend() (malloc.ArenaBackend, error) { if maxAddress < text { maxAddress = math.MaxUint } - be = tryBackendRange(size, end, maxAddress) + be = tryBackendRange(size, info.End, maxAddress) if be != nil { return be, nil } } // Nothing in the ideal range, so search within the acceptable range - minAddress := end - maxCloneDistance - if minAddress > end || minAddress < absMinAddress { + minAddress := info.End - maxCloneDistance + if minAddress > info.End || minAddress < absMinAddress { minAddress = absMinAddress } be := tryBackendRange(size, minAddress, text-pageSize-size) @@ -197,7 +192,7 @@ func initMallocBackend() (malloc.ArenaBackend, error) { if maxAddress < text { maxAddress = math.MaxUint } - be = tryBackendRange(size, end, maxAddress) + be = tryBackendRange(size, info.End, maxAddress) if be != nil { return be, nil } diff --git a/clone_mprotect_darwin.go b/clone_mprotect_darwin.go index c514670..7238147 100644 --- a/clone_mprotect_darwin.go +++ b/clone_mprotect_darwin.go @@ -4,31 +4,33 @@ package redefine /* #include +#include + +// jit_memcpy copies len bytes from src to dst within a JIT write-protected +// scope: pthread_jit_write_protect_np(0) before and (1) after, with an +// I-cache flush in between. The entire operation runs in C so that the return +// to Go code happens after MAP_JIT pages are back in execute mode. +static void jit_memcpy(void *dst, const void *src, size_t len) { + pthread_jit_write_protect_np(0); + memcpy(dst, src, len); + __builtin___clear_cache(dst, (char *)dst + len); + pthread_jit_write_protect_np(1); +} */ import "C" -import ( - "runtime" - - "golang.org/x/sys/unix" -) +import "unsafe" func mprotectHook(inner func(int) error) func(int) error { - return func(prot int) error { - // Instead of calling mprotect, just use Darwin's - // pthread_jit_write_protect_np which is effectively the same - // in this case. - - // This value is thread specific, so lock the running goroutine - // to the system thread. This assumes that this function is - // called in BeginMutate/EndMutate pairs. + return inner +} - if prot&unix.PROT_WRITE != 0 { - runtime.LockOSThread() - C.pthread_jit_write_protect_np(0) - } else { - C.pthread_jit_write_protect_np(1) - runtime.UnlockOSThread() - } - return nil - } +// writeJITCode copies src into dst on MAP_JIT pages. The JIT write-protect +// toggle and I-cache flush happen entirely in C, so the return to Go (which +// executes from the duplicate MAP_JIT text) is always in execute mode. +func writeJITCode(dst, src []byte) { + C.jit_memcpy( + unsafe.Pointer(unsafe.SliceData(dst)), + unsafe.Pointer(unsafe.SliceData(src)), + C.size_t(len(src)), + ) } diff --git a/doc.go b/doc.go index 90b0570..dd84965 100644 --- a/doc.go +++ b/doc.go @@ -5,9 +5,8 @@ // is a fun experiment, but do not use it for production code. // // This project is fundamentally non-portable. OS/Arch support: -// - Full support: Linux/amd64, Windows/amd64, Darwin/amd64, Linux/arm64, Windows/arm64 -// - Might work (untested, but it compiles): FreeBSD/amd64, OpenBSD/amd64, NetBSD/amd64 -// - Known broken: Darwin/arm64 (EACCES errors from mprotect) +// - Full support: Linux, Windows, Darwin/MacOS on amd64 and arm64 +// - Might work (untested, but it compiles): FreeBSD, OpenBSD, NetBSD on amd64 // // Other limitations: // - Relies on internal Go APIs that can break at any time diff --git a/cacheflush_arm64.go b/internal/cacheflush/cacheflush_arm64.go similarity index 86% rename from cacheflush_arm64.go rename to internal/cacheflush/cacheflush_arm64.go index 2166304..60189a6 100644 --- a/cacheflush_arm64.go +++ b/internal/cacheflush/cacheflush_arm64.go @@ -1,6 +1,6 @@ //go:build arm64 -package redefine +package cacheflush import "unsafe" @@ -11,7 +11,7 @@ static void cacheflush(char *start, char *end) { */ import "C" -func cacheflush(buf []byte) { +func Flush(buf []byte) { start := unsafe.Pointer(unsafe.SliceData(buf)) end := unsafe.Pointer(uintptr(len(buf)) + uintptr(start)) C.cacheflush((*C.char)(start), (*C.char)(end)) diff --git a/cacheflush_arm64_nocgo.go b/internal/cacheflush/cacheflush_arm64_nocgo.go similarity index 80% rename from cacheflush_arm64_nocgo.go rename to internal/cacheflush/cacheflush_arm64_nocgo.go index c0590ee..7d45134 100644 --- a/cacheflush_arm64_nocgo.go +++ b/internal/cacheflush/cacheflush_arm64_nocgo.go @@ -1,9 +1,9 @@ //go:build arm64 && !cgo -package redefine +package cacheflush // arm64 requires a C compiler to flush the instruction cache. // Install a C compiler and build with CGO_ENABLED=1. -func cacheflush(buf []byte) { +func Flush(buf []byte) { arm64_requires_cgo_for_instruction_cache_flushing() } diff --git a/cacheflush_fallback.go b/internal/cacheflush/cacheflush_fallback.go similarity index 76% rename from cacheflush_fallback.go rename to internal/cacheflush/cacheflush_fallback.go index 12c6eb1..2371540 100644 --- a/cacheflush_fallback.go +++ b/internal/cacheflush/cacheflush_fallback.go @@ -1,7 +1,7 @@ //go:build !arm64 -package redefine +package cacheflush // This isn't needed on amd64. The arm64 version uses the C builtin which is a // no-op, but avoiding cgo makes cross-compiling easier. -func cacheflush(buf []byte) {} +func Flush(buf []byte) {} diff --git a/internal/mach/vm.go b/internal/mach/vm.go new file mode 100644 index 0000000..5fa507d --- /dev/null +++ b/internal/mach/vm.go @@ -0,0 +1,115 @@ +//go:build darwin && arm64 + +package mach + +/* +#include +#include +*/ +import "C" +import ( + "fmt" + "unsafe" +) + +type KernErr int + +func (e KernErr) Error() string { + // Error strings from https://web.mit.edu/darwin/src/modules/xnu/osfmk/man/vm_remap.html and kern_return.h + switch e { + case C.KERN_INVALID_ADDRESS: + return "Specified address is not currently valid." + case C.KERN_NO_SPACE: + return "There is not enough space in the task's address space to allocate the new region for the memory object." + case C.KERN_PROTECTION_FAILURE: + return "Specified memory is valid, but the backing memory manager is not permitted by the requesting task." + } + return fmt.Sprintf("Unknown error code: %d", e) +} + +const ( + VmProtNone = C.VM_PROT_NONE + VmProtRead = C.VM_PROT_READ + VmProtWrite = C.VM_PROT_WRITE + VmProtExecute = C.VM_PROT_EXECUTE +) + +type VmInfo struct { + Addr unsafe.Pointer + Size uintptr + Prot int + MaxProt int +} + +// VmRemap makes a new virtual memory mapping of srcAddr. If addr is 0 then the +// new mapping will be allocated anywhere, otherwise the page will be requested +// at exactly the given address and may overwrite a previously existing mapping +// at that address. +func VmRemap(addr uintptr, srcAddr uintptr, size uintptr) (*VmInfo, error) { + info := VmInfo{ + Size: size, + } + + var vmAddr C.mach_vm_address_t + vmAddr = C.mach_vm_address_t(addr) + + var flags int + if addr == 0 { + flags |= C.VM_FLAGS_ANYWHERE + } else { + flags |= C.VM_FLAGS_FIXED | C.VM_FLAGS_OVERWRITE + } + + var curProt, maxProt C.vm_prot_t + + ret := C.mach_vm_remap( + C.mach_task_self_, + &vmAddr, + C.mach_vm_address_t(size), + 0, + C.int(flags), + C.mach_task_self_, + C.mach_vm_address_t(srcAddr), + 0, + &curProt, + &maxProt, + C.VM_INHERIT_NONE, + ) + + if ret != 0 { + return nil, KernErr(ret) + } + + info.Addr = unsafe.Pointer(uintptr(vmAddr)) + info.Prot = int(curProt) + info.MaxProt = int(maxProt) + + return &info, nil +} + +// Slice returns a byte slice that uses the backing memory referenced in VmInfo. +func (vmi *VmInfo) Slice() []byte { + if vmi.Addr == nil || vmi.Size == 0 { + return nil + } + return unsafe.Slice((*byte)(vmi.Addr), vmi.Size) +} + +// Unmap deallocates the referenced memory. +func (vmi *VmInfo) Unmap() error { + ret := C.mach_vm_deallocate( + C.mach_task_self_, + C.mach_vm_address_t(uintptr(vmi.Addr)), + C.mach_vm_address_t(vmi.Size), + ) + if ret != 0 { + return KernErr(ret) + } + + vmi.Addr = nil + vmi.Size = 0 + vmi.Prot = 0 + vmi.MaxProt = 0 + + return nil +} diff --git a/internal/mach/vm_test.go b/internal/mach/vm_test.go new file mode 100644 index 0000000..2f8f036 --- /dev/null +++ b/internal/mach/vm_test.go @@ -0,0 +1,68 @@ +//go:build darwin && arm64 + +package mach + +import ( + "syscall" + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" + "golang.org/x/sys/unix" +) + +var pageSize = uintptr(syscall.Getpagesize()) + +func TestRemap(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + // The scenario is two source pages and two destination pages. The + // second destination page is remapped from the source. This tests the + // way we intend to use this function. + src, err := unix.MmapPtr(-1, 0, nil, 2*pageSize, unix.PROT_READ|unix.PROT_WRITE, unix.MAP_ANON|unix.MAP_PRIVATE) + require.NoError(err) + defer unix.MunmapPtr(src, 2*pageSize) + + srcBuf := unsafe.Slice((*byte)(src), 2*pageSize) + for i := range srcBuf { + srcBuf[i] = byte(i) + } + + dest, err := unix.MmapPtr(-1, 0, nil, pageSize, unix.PROT_NONE, unix.MAP_ANON|unix.MAP_PRIVATE|unix.MAP_JIT) + require.NoError(err) + defer unix.MunmapPtr(dest, 2*pageSize) + + destPage1 := unsafe.Slice((*byte)(dest), pageSize) + err = unix.Mprotect(destPage1, unix.PROT_READ|unix.PROT_WRITE) + require.NoError(err) + + for i := range destPage1 { + destPage1[i] = ^byte(i) + } + + remapAddr := uintptr(dest) + pageSize + info, err := VmRemap(remapAddr, uintptr(src)+pageSize, pageSize) + require.NoError(err) + + defer info.Unmap() + + assert.Equal(remapAddr, uintptr(info.Addr)) + assert.Equal(pageSize, info.Size) + assert.Equal(info.Prot, VmProtRead|VmProtWrite) + assert.Equal(info.MaxProt, VmProtRead|VmProtWrite|VmProtExecute) + + destPage2 := info.Slice() + srcPage2 := srcBuf[pageSize:] + + t.Logf("src=0x%x-0x%x dest=0x%x-0x%x", uintptr(src)+pageSize, uintptr(src)+2*pageSize, remapAddr, remapAddr+pageSize) + + assert.Equal(srcPage2, destPage2) + + destPage2[0] = 0x12 + assert.Equal(destPage2[0], srcPage2[0], "writes to dest page are reflected in the source page") + + srcPage2[100] = 0xff + assert.Equal(srcPage2[100], destPage2[100], "writes to source page are reflected in the dest page") +} diff --git a/internal/mach/writeprotect.go b/internal/mach/writeprotect.go new file mode 100644 index 0000000..fc521c4 --- /dev/null +++ b/internal/mach/writeprotect.go @@ -0,0 +1,28 @@ +//go:build darwin && arm64 + +package mach + +import ( + "runtime" +) + +/* +#include +*/ +import "C" + +// JITWriteUnprotect is a wrapper around pthread_jit_write_protect_np to allow +// writes to MAP_JIT allocated data sections. +// +// Because the setting is thread-specific, the goroutine will be locked to its +// OS thread until JITWriteUnprotect is called. +func JITWriteUnprotect() { + runtime.LockOSThread() + C.pthread_jit_write_protect_np(0) +} + +// JITWriteProtect reverses the effects of JITWriteUnprotect. +func JITWriteProtect() { + C.pthread_jit_write_protect_np(1) + runtime.UnlockOSThread() +} diff --git a/internal/static/asm_arm64.s b/internal/static/asm_arm64.s new file mode 100644 index 0000000..2f4d127 --- /dev/null +++ b/internal/static/asm_arm64.s @@ -0,0 +1,19 @@ +#include "go_asm.h" +#include "textflag.h" + +// getFrame returns the value of the frame pointer register. getFrame doesn't +// have a stack frame of its own so this will be the caller's stack frame. +TEXT ·getFrame(SB),NOSPLIT,$0-8 + MOVD R29, ret+8(SP) + RET + +TEXT ·getg(SB),NOSPLIT,$0-8 + MOVD g, ret+8(SP) // g is an alias for x28 + RET + +TEXT ·dupMarker(SB),NOSPLIT,$0-8 + ADR marker, R0 + MOVD R0, ret+8(SP) + RET +marker: + WORD $0 diff --git a/internal/static/duplicate_darwin_arm64.go b/internal/static/duplicate_darwin_arm64.go new file mode 100644 index 0000000..dc7192e --- /dev/null +++ b/internal/static/duplicate_darwin_arm64.go @@ -0,0 +1,260 @@ +//go:build darwin && arm64 + +package static + +import ( + "encoding/binary" + "errors" + "fmt" + "syscall" + "unsafe" + + "github.com/pboyd/redefine/internal/cacheflush" + "github.com/pboyd/redefine/internal/mach" + "golang.org/x/arch/arm64/arm64asm" + "golang.org/x/sys/unix" +) + +var moduledataCopy moduledata + +// duplicate clones the program's static data. +// +// If a duplicate already exists that will be returned instead of making another copy. +func (s *Info) duplicate() (*Info, error) { + if s.dupInfo != nil { + // Already duplicated, don't try and do it again. + return s.dupInfo, nil + } + if s.isDuplicate() { + // This is the duplicate, return this instance. + return s, nil + } + + ptr, err := unix.MmapPtr(-1, 0, unsafe.Pointer(s.End), s.End-s.Start, syscall.PROT_NONE, unix.MAP_ANON|unix.MAP_PRIVATE|unix.MAP_JIT) + if err != nil { + return nil, fmt.Errorf("mmap: %w", err) + } + + s.offset = uintptr(ptr) - s.Start + + s.dupInfo = &Info{ + offset: s.offset, + Start: uintptr(ptr), + End: s.End + s.offset, + datap: &moduledataCopy, + } + + // Copy the entire range [text, rodata) into the duplicate mapping. + // On darwin/arm64 the linker places executable text and read-only + // sections (go.func.*, pclntab, rodata stubs …) together in the + // __TEXT segment, so the copy must extend to datap.rodata, not just + // datap.etext. Pass datap.rodata as textEnd so that fixADRP + // correctly leaves ADRPs targeting pages in [etext, rodata) + // unadjusted — those pages are part of the duplicate. + err = copyText( + s.dupInfo.Start+(s.datap.text-s.Start), + s.datap.text, + s.datap.rodata-s.Start, + s.datap.rodata, + ) + if err != nil { + s.unduplicate() + return nil, fmt.Errorf("text: %w", err) + } + + moduledataCopy = *s.datap + moduledataCopy.text += s.offset + moduledataCopy.etext += s.offset + moduledataCopy.minpc += s.offset + moduledataCopy.maxpc += s.offset + + moduledataCopy.textsectmap = make([]textsect, len(s.datap.textsectmap)) + for i := range moduledataCopy.textsectmap { + moduledataCopy.textsectmap[i] = s.datap.textsectmap[i] + moduledataCopy.textsectmap[i].baseaddr += s.offset + } + + // Register the duplicate moduledata with the Go runtime so that + // findfunc and related functions can locate PCs in the duplicate text. + lastmoduledatap.next = &moduledataCopy + + return s.dupInfo, nil +} + +func copyText(destPtr uintptr, srcPtr uintptr, length uintptr, textEnd uintptr) error { + if length == 0 { + return nil + } + + // pageOffset is the address of srcPtr relative to the start of the + // page. destPtr must have the same value. + pageOffset := srcPtr &^ pageMask + if pageOffset != destPtr&^pageMask { + return errors.New("src and dest have different page offsets") + } + + // offset is the number of bytes between the destination and source + offset := destPtr - srcPtr + + srcPages := unsafe.Slice((*byte)(unsafe.Pointer(srcPtr&pageMask)), int((pageOffset+length+pageSize-1)&pageMask)) + destPages := unsafe.Slice((*byte)(unsafe.Pointer(destPtr&pageMask)), int((pageOffset+length+pageSize-1)&pageMask)) + + // Make our data RWX and keep it that way forever. Writes are blocked + // through the Darwin's pthread_jit_write_protect, not mprotect. + err := unix.Mprotect(destPages, unix.PROT_READ|unix.PROT_WRITE|unix.PROT_EXEC) + if err != nil { + return fmt.Errorf("mprotect: %w", err) + } + + mach.JITWriteUnprotect() + defer mach.JITWriteProtect() + + src := srcPages[pageOffset : pageOffset+length] + dest := destPages[pageOffset : pageOffset+length] + + copy(dest, src) + + // Find the duplicate marker in src, then translate that address to dest and set the value to 1. + *(*uint32)(unsafe.Pointer(uintptr(unsafe.Pointer(dupMarker())) + offset)) = 1 + + if err := fixADRP(dest, srcPtr, textEnd, offset); err != nil { + return fmt.Errorf("fixADRP: %w", err) + } + + cacheflush.Flush(destPages) + + return nil +} + +// unduplicate frees memory allocated by duplicate. +func (s *Info) unduplicate() error { + if s.dupInfo == nil { + return nil + } + + err := unix.MunmapPtr(unsafe.Pointer(s.dupInfo.Start), s.dupInfo.End-s.dupInfo.Start) + if err != nil { + return err + } + + s.dupInfo = nil + + return nil +} + +const ( + // ADR/ADRP is encoded as: + // -------------------------------------------------- + // | P | lo 2 bits | 10000 | hi 19 bits | 5-bit reg | + // -------------------------------------------------- + // Mask for the address: + adrAddressMask = uint32(3<<29 | 0x7ffff<<5) +) + +func fixADRP(code []byte, origText, origEtext, offset uintptr) error { + destBase := uintptr(unsafe.Pointer(unsafe.SliceData(code))) + srcBase := destBase - offset + + // ADRP always uses 4KB page granularity regardless of OS page size. + const adrpPageMask = ^uintptr(0xfff) + origTextPage := origText & adrpPageMask + origEtextPage := (origEtext + 0xfff) & adrpPageMask + + for i := uintptr(0); i < uintptr(len(code)); i += 4 { + raw := code[i : i+4] + inst, err := arm64asm.Decode(raw) + if err != nil { + // Just skip bad instructions. It's probably padding or data. + continue + } + + destPC := destBase + i + srcPC := srcBase + i + + switch inst.Op { + case arm64asm.ADRP: + oldArg := int64(inst.Args[1].(arm64asm.PCRel)) + + // Don't update the address if the target is within the + // original text. We want those to keep the same relative value + // so that they'll point to the new text. + targetPage := uintptr(int64(srcPC&adrpPageMask) + oldArg) + if targetPage >= origTextPage && targetPage < origEtextPage { + continue + } + + newImm := (int64(srcPC&adrpPageMask) + oldArg - int64(destPC&adrpPageMask)) >> 12 + if newImm < -(1<<20) || newImm >= (1<<20) { + return fmt.Errorf("ADRP at byte offset %d: adjusted immediate %d out of 21-bit signed range", i, newImm) + } + newArg := uint32(newImm) + + encoded := binary.LittleEndian.Uint32(raw) &^ adrAddressMask + encoded |= (newArg & 3) << 29 // Lowest 2 bits to bits 30 and 29 + encoded |= ((newArg >> 2) & 0x7ffff) << 5 // Highest 19 bits to bits 23 to 5 + binary.LittleEndian.PutUint32(raw, encoded) + + } + } + return nil +} + +func patchRodataCodePtrs(offset uintptr, moddata *moduledata) error { + if moddata.etext >= moddata.noptrdata { + return nil + } + + mapStart := (moddata.etext + pageSize - 1) & pageMask + mapEnd := moddata.noptrdata & pageMask + if mapStart >= mapEnd { + return nil + } + + entries := make(map[uintptr]struct{}, len(moddata.ftab)) + for _, ft := range moddata.ftab { + entries[moddata.text+uintptr(ft.entryoff)] = struct{}{} + } + + size := mapEnd - mapStart + + tmpPtr, err := unix.MmapPtr(-1, 0, nil, size, + unix.PROT_READ|unix.PROT_WRITE, + unix.MAP_PRIVATE|unix.MAP_ANON) + if err != nil { + return fmt.Errorf("mmap temp rodata (%d bytes): %w", size, err) + } + tmpSlice := unsafe.Slice((*byte)(tmpPtr), int(size)) + copy(tmpSlice, unsafe.Slice((*byte)(unsafe.Pointer(mapStart)), int(size))) + + // ignore pclntable area, because patching those pointers caused crashes. + pclnStart := uintptr(unsafe.Pointer(unsafe.SliceData(moddata.pclntable))) + pclnEnd := pclnStart + uintptr(len(moddata.pclntable)) + + for addr := mapStart; addr+8 <= mapEnd; addr += 8 { + if addr >= pclnStart && addr < pclnEnd { + continue + } + + off := addr - mapStart + val := *(*uintptr)(unsafe.Pointer(&tmpSlice[off])) + if val >= moddata.text && val < moddata.etext { + if _, ok := entries[val]; ok { + *(*uintptr)(unsafe.Pointer(&tmpSlice[off])) = val + offset + } + } + } + + _, err = mach.VmRemap(mapStart, uintptr(tmpPtr), size) + if err != nil { + unix.MunmapPtr(tmpPtr, size) + return fmt.Errorf("vm_remap rodata (%d bytes at %#x): %w", size, mapStart, err) + } + + if err := unix.Mprotect(unsafe.Slice((*byte)(unsafe.Pointer(mapStart)), int(size)), unix.PROT_READ); err != nil { + return fmt.Errorf("mprotect rodata to r: %w", err) + } + + unix.MunmapPtr(tmpPtr, size) + + return nil +} diff --git a/internal/static/duplicate_darwin_arm64_test.go b/internal/static/duplicate_darwin_arm64_test.go new file mode 100644 index 0000000..f9701ab --- /dev/null +++ b/internal/static/duplicate_darwin_arm64_test.go @@ -0,0 +1,172 @@ +//go:build darwin && arm64 + +package static + +import ( + "reflect" + "strconv" + "testing" + "unsafe" + + "github.com/pboyd/redefine/internal/mach" + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestDuplicate(t *testing.T) { + staticVar = 0 + + info := GetInfo() + t.Log("Original (before duplicate)\n" + info.String()) + + dupInfo, err := info.duplicate() + require.NoError(t, err) + + t.Log("Duplicate\n" + dupInfo.String()) + + assert.NotEqual(t, uintptr(0), info.offset) + + t.Run("static data function", func(t *testing.T) { + assert := assert.New(t) + + // Sanity check that staticDataFunc works. + assert.Equal(1, staticDataFunc()) + assert.Equal(2, staticDataFunc()) + + // Now get the copy of staticDataFunc in the duplicated text + // segment. It should update the same static data. + dup := offsetFunc(staticDataFunc, info.offset) + + assert.Equal(3, dup()) + assert.Equal(4, dup()) + + // Finally, make sure that the adds done by dup are reflected + // in the original. + assert.Equal(5, staticDataFunc()) + }) + + t.Run("stack allocation", func(t *testing.T) { + assert := assert.New(t) + dup := offsetFunc(stackFunc, info.offset) + assert.Equal(1, dup(0)) + }) + + t.Run("findfunc", func(t *testing.T) { + assert := assert.New(t) + + dupFindfunc := offsetFunc(findfunc, info.offset) + + fi := dupFindfunc(reflect.ValueOf(staticDataFunc).Pointer() + info.offset) + assert.NotNil(fi._func) + assert.NotNil(fi.datap) + }) + + t.Run("split stack", func(t *testing.T) { + assert := assert.New(t) + + g := getg() + origLo := g.stack.lo + + // Sanity check: stackSplitter really splits the stack + stackSplitter(0, 0) + assert.NotEqual(origLo, g.stack.lo) + + // Another sanity check: we get the same g instance no matter + // where the function runs. + g2 := offsetFunc(getg, info.offset)() + assert.Same(g, g2) + + newLo := g.stack.lo + + offsetFunc(stackSplitter, info.offset)(0, 0) + assert.NotEqual(newLo, g.stack.lo) + }) + + t.Run("FuncSlice", func(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + + text, etext := dupInfo.Text() + + buf, err := info.FuncSlice(staticDataFunc) + require.NoError(err) + + assert.Greater(len(buf), 4) + + bufAddr := uintptr(unsafe.Pointer(unsafe.SliceData(buf))) + assert.True(bufAddr >= text && bufAddr < etext, "pointer to original function in the original instance should still get the duplicate data") + + dupBuf, err := dupInfo.FuncSlice(staticDataFunc) + assert.NoError(err) + + dupBufAddr := uintptr(unsafe.Pointer(unsafe.SliceData(dupBuf))) + assert.Equal(bufAddr, dupBufAddr, "pointer to original function in the duplicate instance should get the duplicate data") + + dupFunc := offsetFunc(staticDataFunc, info.offset) + + dupFuncBuf, err := dupInfo.FuncSlice(dupFunc) + assert.NoError(err) + + dupFuncBufAddr := uintptr(unsafe.Pointer(unsafe.SliceData(dupFuncBuf))) + assert.Equal(bufAddr, dupFuncBufAddr, "pointer to duplicate function in the duplicate instance should get the duplicate data") + }) + + t.Run("Duplicated Text Edits", func(t *testing.T) { + buf, err := info.FuncSlice(uncalledFunc) + require.NoError(t, err) + + // Make sure writes to the duplicated text data don't panic if + // JITWriteUnprotect is called first. + mach.JITWriteUnprotect() + buf[0] = 0 + mach.JITWriteProtect() + }) +} + +var refs []any + +// offsetFunc takes the address of fn and adds offset to it, then derefs that +// address as a function of the same type. +func offsetFunc[T any](fn T, offset uintptr) T { + fnv := reflect.ValueOf(fn) + if fnv.Kind() != reflect.Func { + panic("not a function") + } + + ptr := fnv.Pointer() + offset + ref := &ptr + refs = append(refs, ref) + + return *(*T)(unsafe.Pointer(uintptr(unsafe.Pointer(&ref)))) +} + +var staticVar int + +//go:noinline +func staticDataFunc() int { + staticVar++ + return staticVar +} + +// stackSplitter calls itself recursively until the stack grows. +// +//go:noinline +func stackSplitter(n int, lo uintptr) int { + g := getg() + if lo == 0 { + lo = g.stack.lo + } else if g.stack.lo != lo { + return n + } + + return stackSplitter(n+1, lo) +} + +func stackFunc(n int) int { + str := strconv.Itoa(n) + return len(str) +} + +func uncalledFunc() int { + return 3 +} diff --git a/internal/static/fork.go b/internal/static/fork.go new file mode 100644 index 0000000..9951886 --- /dev/null +++ b/internal/static/fork.go @@ -0,0 +1,66 @@ +//go:build darwin && arm64 + +package static + +import ( + "fmt" + "runtime" +) + +func Fork() error { + if runningInDuplicate() { + // Nothing to do + return nil + } + + info := GetInfo() + _, err := info.duplicate() + if err != nil { + return err + } + + err = patchRodataCodePtrs(info.offset, info.datap) + if err != nil { + return fmt.Errorf("patchRodataCodePtrs: %w", err) + } + + origText, origEtext := info.originalText() + + for f := getFrame(); f != nil; f = f.next { + if f.lr >= origText && f.lr < origEtext { + f.lr += info.offset + } + } + + return nil +} + +type frame struct { + // By convention, Go stores the address of the next frame followed by + // the return address. + next *frame + lr uintptr +} + +func (f *frame) Func() *runtime.Func { + return runtime.FuncForPC(f.lr) +} + +func getFrame() *frame + +type g struct { + stack stack +} + +type stack struct { + lo uintptr + hi uintptr +} + +func getg() *g + +func dupMarker() *uint32 + +func runningInDuplicate() bool { + return *dupMarker() != 0 +} diff --git a/internal/static/fork_fallback.go b/internal/static/fork_fallback.go new file mode 100644 index 0000000..9ca0ea9 --- /dev/null +++ b/internal/static/fork_fallback.go @@ -0,0 +1,7 @@ +//go:build !darwin || !arm64 + +package static + +func Fork() error { + return nil +} diff --git a/internal/static/fork_test.go b/internal/static/fork_test.go new file mode 100644 index 0000000..c70c385 --- /dev/null +++ b/internal/static/fork_test.go @@ -0,0 +1,96 @@ +//go:build darwin && arm64 + +package static + +import ( + "testing" + "unsafe" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +var testForkStatic any + +// forkTestFuncVal is a package-level function variable whose funcval lives in +// rodata. patchRodataCodePtrs must patch its code pointer to the duplicate +// text after Fork(). +// +//go:noinline +func forkTestHelper() int { return 17 } + +var forkTestFuncVal func() int = forkTestHelper + +func TestFork(t *testing.T) { + testForkStatic = int(5) + + assert.False(t, runningInDuplicate()) + + info := GetInfo() + + err := Fork() + require.NoError(t, err) + assert.NotNil(t, info.dupInfo) + + assert.True(t, runningInDuplicate()) + + t.Run("goroutines", func(t *testing.T) { + assert.True(t, runningInDuplicate()) + + ch := make(chan bool, 1) + go func() { + defer close(ch) + ch <- runningInDuplicate() + }() + assert.True(t, <-ch) + }) + + t.Run("type assertions", func(t *testing.T) { + v, ok := testForkStatic.(int) + assert.True(t, ok) + assert.Equal(t, 5, v) + }) + + t.Run("funcval dispatch after fork", func(t *testing.T) { + // A Go func value is a pointer to a funcval whose first word is + // the code entry address. For a package-level function variable + // the funcval lives in rodata, so patchRodataCodePtrs should + // have updated its code pointer to the dupInfo text. + dupText, dupEtext := info.dupInfo.Text() + + // Dereference the func value to get the funcval pointer, then + // read the first word (the code pointer). + fvPtr := *(*uintptr)(unsafe.Pointer(&forkTestFuncVal)) + codePtr := *(*uintptr)(unsafe.Pointer(fvPtr)) + + assert.True(t, codePtr >= dupText && codePtr < dupEtext, + "funcval code pointer 0x%x should be in duplicate text [0x%x, 0x%x) after Fork()", + codePtr, dupText, dupEtext) + + // The funcval must still dispatch correctly. + assert.Equal(t, 17, forkTestFuncVal()) + }) +} + +func TestFrame(t *testing.T) { + assert := assert.New(t) + + f := getFrame() + assert.Greater(f.lr, lastmoduledatap.minpc) + assert.Less(f.lr, lastmoduledatap.maxpc) + + // Check the name of the function that called this test. It may change + // in the future which would break this test. + fn := f.Func() + assert.Equal("testing.tRunner", fn.Name()) + + // There should at least be one additional caller. + f = f.next + assert.NotNil(f) + + for ; f != nil; f = f.next { + t.Logf("name=%s", f.Func().Name()) + assert.Greater(f.lr, lastmoduledatap.minpc) + assert.Less(f.lr, lastmoduledatap.maxpc) + } +} diff --git a/findfunc.go b/internal/static/moddata_go125.go similarity index 71% rename from findfunc.go rename to internal/static/moddata_go125.go index dc66658..31434bc 100644 --- a/findfunc.go +++ b/internal/static/moddata_go125.go @@ -1,6 +1,6 @@ -package redefine +//go:build go1.25 && !go1.26 -import _ "unsafe" +package static type funcInfo struct { *_func @@ -55,7 +55,27 @@ type moduledata struct { rodata uintptr gofunc uintptr // go.func.* - // Struct continues, omitting unused fields. + textsectmap []textsect + + // The following fields exist in the runtime struct but are not used by + // this package. They are included here to correctly place the next field + // at the same offset as in the runtime's moduledata struct. + _typelinks [3]uintptr // []int32 + _itablinks [3]uintptr // []*itab + _ptab [3]uintptr // []ptabEntry + _pluginpath [2]uintptr // string + _pkghashes [3]uintptr // []modulehash + _inittasks [3]uintptr // []*initTask + _modulename [2]uintptr // string + _modulehashes [3]uintptr // []modulehash + _hasmain uint8 + _bad bool + _ [6]byte // padding to align the following bitvectors + _gcdatamask [2]uintptr // bitvector + _gcbssmask [2]uintptr // bitvector + _typemap uintptr // map[typeOff]*_type (a pointer) + + next *moduledata } // pcHeader holds data used by the pclntab lookups. @@ -79,5 +99,8 @@ type functab struct { funcoff uint32 } -//go:linkname findfunc runtime.findfunc -func findfunc(pc uintptr) funcInfo +type textsect struct { + vaddr uintptr // prelinked section vaddr + end uintptr // vaddr + section length + baseaddr uintptr // relocated section address +} diff --git a/internal/static/moddata_go126.go b/internal/static/moddata_go126.go new file mode 100644 index 0000000..6340985 --- /dev/null +++ b/internal/static/moddata_go126.go @@ -0,0 +1,107 @@ +//go:build go1.26 + +package static + +type funcInfo struct { + *_func + datap *moduledata +} + +type _func struct { + //sys.NotInHeap // Only in static data + + entryOff uint32 // start pc, as offset from moduledata.text/pcHeader.textStart + nameOff int32 // function name, as index into moduledata.funcnametab. + + args int32 // in/out args size + deferreturn uint32 // offset of start of a deferreturn call instruction from entry, if any. + + pcsp uint32 + pcfile uint32 + pcln uint32 + npcdata uint32 + cuOffset uint32 // runtime.cutab offset of this function's CU + startLine int32 // line number of start of function (func keyword/TEXT directive) + funcID uint8 // set for certain special runtime functions + flag uint8 + _ [1]byte // pad + nfuncdata uint8 // must be last, must end on a uint32-aligned boundary +} + +// moduledata records information about the layout of the executable +// image. It is written by the linker. Any changes here must be +// matched changes to the code in cmd/link/internal/ld/symtab.go:symtab. +// moduledata is stored in statically allocated non-pointer memory; +// none of the pointers here are visible to the garbage collector. +type moduledata struct { + pcHeader *pcHeader + funcnametab []byte + cutab []uint32 + filetab []byte + pctab []byte + pclntable []byte + ftab []functab + findfunctab uintptr + minpc, maxpc uintptr + + text, etext uintptr + noptrdata, enoptrdata uintptr + data, edata uintptr + bss, ebss uintptr + noptrbss, enoptrbss uintptr + covctrs, ecovctrs uintptr + end, gcdata, gcbss uintptr + types, etypes uintptr + rodata uintptr + gofunc uintptr // go.func.* + epclntab uintptr + + textsectmap []textsect + + // The following fields exist in the runtime struct but are not used by + // this package. They are included here to correctly place the next field + // at the same offset as in the runtime's moduledata struct. + _typelinks [3]uintptr // []int32 + _itablinks [3]uintptr // []*itab + _ptab [3]uintptr // []ptabEntry + _pluginpath [2]uintptr // string + _pkghashes [3]uintptr // []modulehash + _inittasks [3]uintptr // []*initTask + _modulename [2]uintptr // string + _modulehashes [3]uintptr // []modulehash + _hasmain uint8 + _bad bool + _ [6]byte // padding to align the following bitvectors + _gcdatamask [2]uintptr // bitvector + _gcbssmask [2]uintptr // bitvector + _typemap uintptr // map[typeOff]*_type (a pointer) + + next *moduledata +} + +// pcHeader holds data used by the pclntab lookups. +type pcHeader struct { + magic uint32 // 0xFFFFFFF1 + pad1, pad2 uint8 // 0,0 + minLC uint8 // min instruction size + ptrSize uint8 // size of a ptr in bytes + nfunc int // number of functions in the module + nfiles uint // number of entries in the file tab + textStart uintptr // base for function entry PC offsets in this module, equal to moduledata.text + funcnameOffset uintptr // offset to the funcnametab variable from pcHeader + cuOffset uintptr // offset to the cutab variable from pcHeader + filetabOffset uintptr // offset to the filetab variable from pcHeader + pctabOffset uintptr // offset to the pctab variable from pcHeader + pclnOffset uintptr // offset to the pclntab variable from pcHeader +} + +type functab struct { + entryoff uint32 // relative to runtime.text + funcoff uint32 +} + +type textsect struct { + vaddr uintptr // prelinked section vaddr + end uintptr // vaddr + section length + baseaddr uintptr // relocated section address +} diff --git a/internal/static/static.go b/internal/static/static.go new file mode 100644 index 0000000..988ac61 --- /dev/null +++ b/internal/static/static.go @@ -0,0 +1,152 @@ +package static + +import ( + "errors" + "fmt" + "reflect" + "runtime" + "strings" + "sync" + "syscall" + "unsafe" +) + +//go:linkname findfunc runtime.findfunc +func findfunc(pc uintptr) funcInfo + +//go:linkname lastmoduledatap runtime.lastmoduledatap +var lastmoduledatap *moduledata + +var pageSize = uintptr(syscall.Getpagesize()) +var pageMask = ^(pageSize - 1) + +var info Info +var infoInit sync.Once + +func GetInfo() *Info { + infoInit.Do(func() { + // The info is based on the module of this function. This is + // typically the main program, but it may not be under + // -buildmode=plugin or -buildmode=shared. + pc, _, _, _ := runtime.Caller(0) + datap := findfunc(pc).datap + + // Align start and end to page boundaries + start := datap.text & pageMask + length := ((datap.end - start) + pageSize - 1) & pageMask + end := start + length + + info = Info{ + datap: datap, + Start: datap.text & pageMask, + End: end, + } + }) + + return &info +} + +type Info struct { + // Delta from an original address to the duplicate. + offset uintptr + + Start, End uintptr + + datap *moduledata + + dupInfo *Info +} + +func (s *Info) isDuplicate() bool { + // If offset is 0 there is no duplicate at all. And the duplicate field + // is only populated on the original. + return s.offset > 0 && s.dupInfo == nil +} + +// Text returns the address of the beginning and end of the text segment. +func (s *Info) Text() (text uintptr, etext uintptr) { + text = s.datap.text + etext = s.datap.etext + return +} + +func (s *Info) originalText() (text uintptr, etext uintptr) { + text, etext = s.Text() + if s.isDuplicate() { + negOffset := -s.offset + text += negOffset + etext += negOffset + } + return text, etext +} + +// FuncSlice returns a slice containing the machine instructions for a function. +func (s *Info) FuncSlice(fn any) ([]byte, error) { + fnv := reflect.ValueOf(fn) + if fnv.Kind() != reflect.Func { + return nil, fmt.Errorf("not a function, kind: %v", fnv.Kind()) + } + entry := fnv.Pointer() + + datap := findfunc(entry).datap + if datap == nil { + return nil, errors.New("no moduledata for function") + } + + text := datap.text + etext := datap.etext + ftab := datap.ftab + + // To find the length, look at the offsets of every function and find + // the one that comes immediately after this one. + + // TODO: ftab seems to be ordered, can we rely on that to speed this up? + + funcOffset := uint32(entry - text) + length := uint32(etext - entry) + + for _, ft := range ftab { + // Does this function come before the one we're looking for? + if ft.entryoff <= funcOffset { + continue + } + + // Is the distance between these two functions less than what we've seen before? + testLength := ft.entryoff - funcOffset + if testLength < length { + length = testLength + } + } + + // If there's a duplicate, always return the address to its copy of the function. + if s.offset > 0 { + origText, origEtext := s.originalText() + if entry >= origText && entry < origEtext { + entry += s.offset + } + } + + return unsafe.Slice((*byte)(unsafe.Pointer(entry)), length), nil +} + +func (s *Info) String() string { + var b strings.Builder + + fmt.Fprintf(&b, "addr: 0x%x\n", uintptr(unsafe.Pointer(s))) + fmt.Fprintf(&b, "offset: 0x%x\n", s.offset) + fmt.Fprintf(&b, "Start: 0x%x\n", s.Start) + fmt.Fprintf(&b, "End: 0x%x\n", s.End) + fmt.Fprintf(&b, "datap:\n") + fmt.Fprintf(&b, " %-10s 0x%x - 0x%x\n", "text", s.datap.text, s.datap.etext) + fmt.Fprintf(&b, " %-10s 0x%x - 0x%x\n", "noptrdata", s.datap.noptrdata, s.datap.enoptrdata) + fmt.Fprintf(&b, " %-10s 0x%x - 0x%x\n", "data", s.datap.data, s.datap.edata) + fmt.Fprintf(&b, " %-10s 0x%x - 0x%x\n", "bss", s.datap.bss, s.datap.ebss) + fmt.Fprintf(&b, " %-10s 0x%x - 0x%x\n", "noptrbss", s.datap.noptrbss, s.datap.enoptrbss) + fmt.Fprintf(&b, " %-10s 0x%x - 0x%x\n", "covctrs", s.datap.covctrs, s.datap.ecovctrs) + fmt.Fprintf(&b, " %-10s 0x%x - 0x%x\n", "types", s.datap.types, s.datap.etypes) + fmt.Fprintf(&b, " %-10s 0x%x\n", "end", s.datap.end) + fmt.Fprintf(&b, " %-10s 0x%x\n", "gcdata", s.datap.gcdata) + fmt.Fprintf(&b, " %-10s 0x%x\n", "gcbss", s.datap.gcbss) + + return b.String() +} diff --git a/internal/static/static_test.go b/internal/static/static_test.go new file mode 100644 index 0000000..88a8c6d --- /dev/null +++ b/internal/static/static_test.go @@ -0,0 +1,23 @@ +package static + +import ( + "testing" + + "github.com/stretchr/testify/assert" + "github.com/stretchr/testify/require" +) + +func TestFuncSlice(t *testing.T) { + assert := assert.New(t) + require := require.New(t) + info := GetInfo() + + buf, err := info.FuncSlice(testFunc) + require.NoError(err) + + assert.Greater(len(buf), 4) +} + +func testFunc() int { + return 5 +} diff --git a/mmap_flags_darwin.go b/mmap_flags_darwin.go index 01df543..679af26 100644 --- a/mmap_flags_darwin.go +++ b/mmap_flags_darwin.go @@ -2,8 +2,8 @@ package redefine -import "golang.org/x/sys/unix" - -// Darwin has no equivalent to MAP_FIXED_NOREPLACE. But MAP_JIT is required to -// use PROT_WRITE and PROT_EXEC together. -const _MMAP_FLAGS = unix.MAP_JIT +// The allocator uses non-MAP_JIT pages so that pthread_jit_write_protect_np +// is never toggled from MAP_JIT (duplicate text) code. W^X is maintained by +// using PROT_READ|PROT_WRITE during code generation and PROT_READ|PROT_EXEC +// during execution, toggled via regular mprotect. +const _MMAP_FLAGS = 0 diff --git a/redefine.go b/redefine.go index 09a4e48..b6472d5 100644 --- a/redefine.go +++ b/redefine.go @@ -4,7 +4,8 @@ import ( "fmt" "reflect" "sync" - "unsafe" + + "github.com/pboyd/redefine/internal/static" ) var mu sync.RWMutex @@ -144,7 +145,7 @@ func Restore[T any](fn T) error { return fmt.Errorf("unknown function type: %T", cloned) } - code, err := funcSlice(fn) + code, err := static.GetInfo().FuncSlice(fn) if err != nil { return err } @@ -152,90 +153,41 @@ func Restore[T any](fn T) error { return fmt.Errorf("func length mismatch %d != %d", len(code), len(clonedType.originalCode)) } - err = mprotect(code, mprotectRWX) - if err != nil { - return fmt.Errorf("mprotect: %w", err) + if err = applyCodeCopy(code, clonedType.originalCode); err != nil { + return fmt.Errorf("restore code: %w", err) } - defer mprotect(code, mprotectRX) - - copy(code, clonedType.originalCode) clonedType.Free() delete(redefined, fnv.Pointer()) - cacheflush(code) - return nil } // unsafeFunc redefines a function after the safety checks. func unsafeFunc[T any](fn T, newFn any) error { - code, err := funcSlice(fn) - if err != nil { - return err - } - // Locked to prevent simultaneous writes to the map and competing // mprotect calls mu.Lock() defer mu.Unlock() - addr := reflect.ValueOf(fn).Pointer() - if _, ok := redefined[addr]; !ok { - redefined[addr], err = cloneFunc(fn) - if err != nil { - // TODO: Should this be fatal? - return fmt.Errorf("unable to clone function: %w", err) - } - } - - err = mprotect(code, mprotectRWX) + err := static.Fork() if err != nil { - return fmt.Errorf("mprotect: %w", err) + return fmt.Errorf("failed to re-allocate program text segment: %w", err) } - defer mprotect(code, mprotectRX) - err = insertJump(code, reflect.ValueOf(newFn).Pointer()) + code, err := static.GetInfo().FuncSlice(fn) if err != nil { return err } - cacheflush(code) - return nil -} - -// funcSlice returns a slice containing the machine instructions for a function. -func funcSlice(fn any) ([]byte, error) { - fnv := reflect.ValueOf(fn) - if fnv.Kind() != reflect.Func { - return nil, fmt.Errorf("not a function, kind: %v", fnv.Kind()) - } - - entry := fnv.Pointer() - - // To find the length, look at the offsets of every function and find - // the one that comes immediately after this one. - - // TODO: Is there a better way to do this? - // - ftab seems to be ordered so could it find the next entry that way? - // - is the info stored somewhere more conveniently in datap? - - info := findfunc(entry) - funcOffset := uint32(entry - info.datap.text) - length := uint32(info.datap.etext - entry) - - for _, ft := range info.datap.ftab { - // Does this function come before the one we're looking for? - if ft.entryoff <= funcOffset { - continue - } - - // Is the distance between these two functions less than what we've seen before? - testLength := ft.entryoff - funcOffset - if testLength < length { - length = testLength + addr := reflect.ValueOf(fn).Pointer() + if _, ok := redefined[addr]; !ok { + redefined[addr], err = cloneFunc(fn) + if err != nil { + // TODO: Should this be fatal? + return fmt.Errorf("unable to clone function: %w", err) } } - return unsafe.Slice((*byte)(unsafe.Pointer(entry)), length), nil + return applyCodeJump(code, reflect.ValueOf(newFn).Pointer()) } diff --git a/syscalls_darwin_arm64.go b/syscalls_darwin_arm64.go new file mode 100644 index 0000000..82086b5 --- /dev/null +++ b/syscalls_darwin_arm64.go @@ -0,0 +1,50 @@ +//go:build darwin && arm64 + +package redefine + +import ( + "fmt" + "syscall" + "unsafe" +) + +const ( + mprotectExec = 0 // unused: allocator does not use MAP_JIT, no simultaneous W+X needed + mprotectRX = syscall.PROT_READ | syscall.PROT_EXEC + // W^X: BeginMutate uses RW (not RWX); execute permission is restored by EndMutate. + mprotectRWX = syscall.PROT_READ | syscall.PROT_WRITE +) + +// makeRWX and makeRX are no-ops on darwin/arm64. MAP_JIT text patching is +// handled by applyCodeJump / applyCodeCopy which bracket the JIT write in C. +func makeRWX(buf []byte) error { return nil } +func makeRX(buf []byte) error { return nil } + +// applyCodeCopy writes src into dst on MAP_JIT pages via writeJITCode, which +// performs the pthread_jit_write_protect_np toggle and I-cache flush entirely +// in C so that the return to Go is always in execute mode. +func applyCodeCopy(dst, src []byte) error { + writeJITCode(dst[:len(src)], src) + return nil +} + +// applyCodeJump encodes a B instruction targeting dest for execution at +// code's address, then writes it to code via writeJITCode. +func applyCodeJump(code []byte, dest uintptr) error { + srcAddr := uintptr(unsafe.Pointer(unsafe.SliceData(code))) + offset := int64(dest) - int64(srcAddr) + if offset < -(1<<27) || offset >= (1<<27) { + return fmt.Errorf("B target out of range: %d bytes exceeds 128MiB", offset) + } + + // Build the full replacement in a temp buffer (correct for srcAddr), then + // atomically write it to the live MAP_JIT page via jit_memcpy. + tmp := make([]byte, len(code)) + encodeB(tmp, int32(offset)) + // zero-pad the rest (same as insertJump does) + for i := 4; i < len(tmp); i++ { + tmp[i] = 0 + } + writeJITCode(code, tmp) + return nil +} diff --git a/syscalls_unix.go b/syscalls_unix.go index 2d4c392..efcbb1c 100644 --- a/syscalls_unix.go +++ b/syscalls_unix.go @@ -1,11 +1,13 @@ -//go:build linux || darwin || openbsd || netbsd || freebsd +//go:build linux || (darwin && amd64) || openbsd || netbsd || freebsd package redefine import ( + "fmt" "syscall" "unsafe" + "github.com/pboyd/redefine/internal/cacheflush" "golang.org/x/sys/unix" ) @@ -15,6 +17,36 @@ const ( mprotectRWX = syscall.PROT_READ | syscall.PROT_WRITE | syscall.PROT_EXEC ) +func makeRWX(buf []byte) error { + return mprotect(buf, mprotectRWX) +} + +func makeRX(buf []byte) error { + return mprotect(buf, mprotectRX) +} + +func applyCodeCopy(dst, src []byte) error { + if err := makeRWX(dst); err != nil { + return err + } + defer makeRX(dst) + copy(dst[:len(src)], src) + cacheflush.Flush(dst) + return nil +} + +func applyCodeJump(code []byte, dest uintptr) error { + if err := makeRWX(code); err != nil { + return fmt.Errorf("mprotect: %w", err) + } + defer makeRX(code) + if err := insertJump(code, dest); err != nil { + return err + } + cacheflush.Flush(code) + return nil +} + func mprotect(buf []byte, flags int) error { pageSize := syscall.Getpagesize() diff --git a/syscalls_windows.go b/syscalls_windows.go index c79c228..136befe 100644 --- a/syscalls_windows.go +++ b/syscalls_windows.go @@ -3,9 +3,11 @@ package redefine import ( + "fmt" "syscall" "unsafe" + "github.com/pboyd/redefine/internal/cacheflush" "golang.org/x/sys/windows" ) @@ -15,6 +17,36 @@ const ( mprotectRWX = windows.PAGE_EXECUTE_READWRITE ) +func makeRWX(buf []byte) error { + return mprotect(buf, mprotectRWX) +} + +func makeRX(buf []byte) error { + return mprotect(buf, mprotectRX) +} + +func applyCodeCopy(dst, src []byte) error { + if err := makeRWX(dst); err != nil { + return err + } + defer makeRX(dst) + copy(dst[:len(src)], src) + cacheflush.Flush(dst) + return nil +} + +func applyCodeJump(code []byte, dest uintptr) error { + if err := makeRWX(code); err != nil { + return fmt.Errorf("mprotect: %w", err) + } + defer makeRX(code) + if err := insertJump(code, dest); err != nil { + return err + } + cacheflush.Flush(code) + return nil +} + func mprotect(buf []byte, flags int) error { pageSize := syscall.Getpagesize()