From 02448781e37beb9dfa8023c88e6f4749ba27166f Mon Sep 17 00:00:00 2001 From: m96-chan Date: Fri, 6 Mar 2026 16:30:40 +0900 Subject: [PATCH] Add Haskell bindings for oxbitnet (Issue #8) Wrap the oxbitnet-ffi C API via Haskell's FFI (hsc2hs) for type-safe, cross-platform struct layout. High-level API uses MVar + bracket for thread-safe resource management, with IO Bool token callbacks for idiomatic early stopping. Co-Authored-By: Claude Opus 4.6 --- .github/workflows/publish.yml | 47 +++ packages/rust/crates/oxbitnet-haskell/LICENSE | 21 ++ .../rust/crates/oxbitnet-haskell/README.md | 143 ++++++++ .../crates/oxbitnet-haskell/examples/Chat.hs | 64 ++++ .../crates/oxbitnet-haskell/oxbitnet.cabal | 42 +++ .../src/Foreign/OxBitNet/Raw.hsc | 221 ++++++++++++ .../crates/oxbitnet-haskell/src/OxBitNet.hs | 321 ++++++++++++++++++ 7 files changed, 859 insertions(+) create mode 100644 packages/rust/crates/oxbitnet-haskell/LICENSE create mode 100644 packages/rust/crates/oxbitnet-haskell/README.md create mode 100644 packages/rust/crates/oxbitnet-haskell/examples/Chat.hs create mode 100644 packages/rust/crates/oxbitnet-haskell/oxbitnet.cabal create mode 100644 packages/rust/crates/oxbitnet-haskell/src/Foreign/OxBitNet/Raw.hsc create mode 100644 packages/rust/crates/oxbitnet-haskell/src/OxBitNet.hs diff --git a/.github/workflows/publish.yml b/.github/workflows/publish.yml index 0d09281..6da0e14 100644 --- a/.github/workflows/publish.yml +++ b/.github/workflows/publish.yml @@ -37,6 +37,8 @@ jobs: "$(grep -m1 '^version' packages/rust/crates/oxbitnet-python/pyproject.toml | sed 's/.*"\(.*\)"/\1/')" check packages/rust/crates/oxbitnet-java/java/build.gradle.kts \ "$(grep -m1 '^version' packages/rust/crates/oxbitnet-java/java/build.gradle.kts | sed 's/.*"\(.*\)"/\1/')" + check packages/rust/crates/oxbitnet-haskell/oxbitnet.cabal \ + "$(grep -m1 '^version:' packages/rust/crates/oxbitnet-haskell/oxbitnet.cabal | sed 's/.*: *//')" exit $ERRORS # ── crates.io ── @@ -307,3 +309,48 @@ jobs: ORG_GRADLE_PROJECT_signingInMemoryKeyId: ${{ secrets.GPG_KEY_ID }} ORG_GRADLE_PROJECT_signingInMemoryKey: ${{ secrets.GPG_PRIVATE_KEY }} ORG_GRADLE_PROJECT_signingInMemoryKeyPassword: ${{ secrets.GPG_PASSPHRASE }} + + # ── Hackage (Haskell) ── + publish-hackage: + needs: version-check + if: startsWith(github.ref, 'refs/tags/v') + runs-on: ubuntu-latest + steps: + - uses: actions/checkout@v4 + + - uses: haskell-actions/setup@v2 + with: + ghc-version: "9.6" + cabal-version: "3.10" + + - uses: dtolnay/rust-toolchain@stable + + - uses: Swatinem/rust-cache@v2 + with: + workspaces: packages/rust + + - name: Build native library (needed for hsc2hs) + working-directory: packages/rust + run: cargo build -p oxbitnet-ffi --release + + - name: Build Haskell package + working-directory: packages/rust/crates/oxbitnet-haskell + run: | + cabal build all \ + --extra-lib-dirs=../../target/release \ + --extra-include-dirs=../oxbitnet-ffi + + - name: Create source distribution + working-directory: packages/rust/crates/oxbitnet-haskell + run: cabal sdist + + - name: Upload to Hackage + working-directory: packages/rust/crates/oxbitnet-haskell + run: | + cabal upload --publish \ + dist-newstyle/sdist/oxbitnet-*.tar.gz \ + --username "$HACKAGE_USERNAME" \ + --password "$HACKAGE_PASSWORD" + env: + HACKAGE_USERNAME: ${{ secrets.HACKAGE_USERNAME }} + HACKAGE_PASSWORD: ${{ secrets.HACKAGE_PASSWORD }} diff --git a/packages/rust/crates/oxbitnet-haskell/LICENSE b/packages/rust/crates/oxbitnet-haskell/LICENSE new file mode 100644 index 0000000..4dfe64f --- /dev/null +++ b/packages/rust/crates/oxbitnet-haskell/LICENSE @@ -0,0 +1,21 @@ +MIT License + +Copyright (c) 2026 Yusuke Harada + +Permission is hereby granted, free of charge, to any person obtaining a copy +of this software and associated documentation files (the "Software"), to deal +in the Software without restriction, including without limitation the rights +to use, copy, modify, merge, publish, distribute, sublicense, and/or sell +copies of the Software, and to permit persons to whom the Software is +furnished to do so, subject to the following conditions: + +The above copyright notice and this permission notice shall be included in all +copies or substantial portions of the Software. + +THE SOFTWARE IS PROVIDED "AS IS", WITHOUT WARRANTY OF ANY KIND, EXPRESS OR +IMPLIED, INCLUDING BUT NOT LIMITED TO THE WARRANTIES OF MERCHANTABILITY, +FITNESS FOR A PARTICULAR PURPOSE AND NONINFRINGEMENT. IN NO EVENT SHALL THE +AUTHORS OR COPYRIGHT HOLDERS BE LIABLE FOR ANY CLAIM, DAMAGES OR OTHER +LIABILITY, WHETHER IN AN ACTION OF CONTRACT, TORT OR OTHERWISE, ARISING FROM, +OUT OF OR IN CONNECTION WITH THE SOFTWARE OR THE USE OR OTHER DEALINGS IN THE +SOFTWARE. diff --git a/packages/rust/crates/oxbitnet-haskell/README.md b/packages/rust/crates/oxbitnet-haskell/README.md new file mode 100644 index 0000000..210ad44 --- /dev/null +++ b/packages/rust/crates/oxbitnet-haskell/README.md @@ -0,0 +1,143 @@ +# oxbitnet-haskell + +Haskell bindings for [oxbitnet](https://crates.io/crates/oxbitnet) — run [BitNet b1.58](https://github.com/microsoft/BitNet) ternary LLMs with GPU acceleration (wgpu). + +Part of [0xBitNet](https://github.com/m96-chan/0xBitNet). + +## Build + +First, build the native library: + +```bash +cd packages/rust +cargo build -p oxbitnet-ffi --release +``` + +Produces `target/release/liboxbitnet_ffi.so` (Linux) / `.dylib` (macOS) / `oxbitnet_ffi.dll` (Windows). + +Then build the Haskell package: + +```bash +cd packages/rust/crates/oxbitnet-haskell +cabal build all \ + --extra-lib-dirs=../../target/release \ + --extra-include-dirs=../oxbitnet-ffi +``` + +## Quick Start + +```haskell +import OxBitNet + +main :: IO () +main = withBitNet "model.gguf" defaultLoadOptions $ \model -> do + -- Raw prompt + generate model "Hello!" defaultGenerateOptions $ \token -> do + putStr token + return False -- False = continue, True = stop + + -- Chat messages + chat model [userMessage "Hello!"] defaultGenerateOptions $ \token -> do + putStr token + return False +``` + +## API + +### Loading + +```haskell +-- Bracket-based (recommended) +withBitNet "model.gguf" defaultLoadOptions $ \model -> do + ... + +-- Manual load/free +model <- loadBitNet "model.gguf" defaultLoadOptions +-- ...use model... +freeBitNet model + +-- With progress callback +let opts = defaultLoadOptions + { onProgress = Just $ \p -> + putStrLn $ show (lpPhase p) ++ " " ++ show (lpFraction p * 100) ++ "%" + } +withBitNet "model.gguf" opts $ \model -> ... +``` + +### Generation + +```haskell +-- Raw prompt — tokens delivered via callback +generate model "Once upon a time" defaultGenerateOptions $ \token -> do + putStr token + return False -- continue + +-- With custom options +let opts = defaultGenerateOptions + { maxTokens = 512, temperature = 0.7, topK = 40 } +generate model "Hello!" opts $ \token -> do + putStr token + return False + +-- Stop early +generate model "Hello!" defaultGenerateOptions $ \token -> do + putStr token + return (token == "\n") -- stop on newline +``` + +### Chat + +```haskell +let messages = + [ systemMessage "You are a helpful assistant." + , userMessage "What is 2+2?" + ] + +chat model messages defaultGenerateOptions $ \token -> do + putStr token + return False +``` + +### Logger + +```haskell +-- Install before loading any model (can only be called once) +setLogger Info $ \level msg -> + putStrLn $ "[" ++ show level ++ "] " ++ msg +``` + +### Cleanup + +`withBitNet` handles cleanup automatically via `bracket`. For manual management, use `loadBitNet` / `freeBitNet`. Calling `freeBitNet` multiple times is safe. + +## Generation Options + +| Field | Default | Description | +|-------|---------|-------------| +| `maxTokens` | 256 | Maximum tokens to generate | +| `temperature` | 1.0 | Sampling temperature | +| `topK` | 50 | Top-k sampling | +| `repeatPenalty` | 1.1 | Repetition penalty | +| `repeatLastN` | 64 | Window for repetition penalty | + +## Exceptions + +All errors are thrown as `OxBitNetException`: + +- `LoadError String` — model failed to load +- `GenerateError String` — generation failed +- `Disposed` — attempted to use a freed model handle + +## Running the Example + +```bash +cd packages/rust +cargo build -p oxbitnet-ffi --release +cd crates/oxbitnet-haskell +cabal run oxbitnet-chat -- /path/to/model.gguf \ + --extra-lib-dirs=../../target/release +``` + +## License + +MIT diff --git a/packages/rust/crates/oxbitnet-haskell/examples/Chat.hs b/packages/rust/crates/oxbitnet-haskell/examples/Chat.hs new file mode 100644 index 0000000..0e83ca8 --- /dev/null +++ b/packages/rust/crates/oxbitnet-haskell/examples/Chat.hs @@ -0,0 +1,64 @@ +module Main where + +import Control.Monad (unless) +import Data.IORef +import System.Environment (getArgs) +import System.IO (hFlush, stdout, hIsEOF, stdin, hSetBuffering, BufferMode(..)) + +import OxBitNet + +main :: IO () +main = do + hSetBuffering stdout NoBuffering + args <- getArgs + source <- case args of + [s] -> return s + _ -> error "Usage: oxbitnet-chat " + + let loadOpts = defaultLoadOptions + { onProgress = Just $ \p -> do + let phaseName = case lpPhase p of + Download -> "Download" + Parse -> "Parse" + Upload -> "Upload" + pct = lpFraction p * 100.0 :: Double + putStr $ "\r[" ++ phaseName ++ "] " ++ showFFloat1 pct ++ "%" + hFlush stdout + } + + putStrLn "Loading model..." + withBitNet source loadOpts $ \model -> do + putStrLn "\nModel loaded. Type a message (Ctrl-D to quit).\n" + history <- newIORef ([] :: [ChatMessage]) + chatLoop model history + +chatLoop :: BitNet -> IORef [ChatMessage] -> IO () +chatLoop model history = do + putStr "> " + hFlush stdout + eof <- hIsEOF stdin + unless eof $ do + input <- getLine + unless (null input) $ do + modifyIORef' history (++ [userMessage input]) + msgs <- readIORef history + + responseRef <- newIORef "" + chat model msgs defaultGenerateOptions $ \token -> do + putStr token + modifyIORef' responseRef (++ token) + return False + + putStrLn "" + response <- readIORef responseRef + modifyIORef' history (++ [assistantMessage response]) + + chatLoop model history + +-- | Show a Double with 1 decimal place. +showFFloat1 :: Double -> String +showFFloat1 x = + let n = round (x * 10) :: Int + whole = n `div` 10 + frac = n `mod` 10 + in show whole ++ "." ++ show (abs frac) diff --git a/packages/rust/crates/oxbitnet-haskell/oxbitnet.cabal b/packages/rust/crates/oxbitnet-haskell/oxbitnet.cabal new file mode 100644 index 0000000..780395e --- /dev/null +++ b/packages/rust/crates/oxbitnet-haskell/oxbitnet.cabal @@ -0,0 +1,42 @@ +cabal-version: 2.4 +name: oxbitnet +version: 0.5.2 +synopsis: Haskell bindings for oxbitnet — BitNet b1.58 inference with wgpu +description: + Run BitNet b1.58 ternary LLMs with GPU acceleration (wgpu). + . + This package provides Haskell bindings to the @liboxbitnet_ffi@ C library. + You must build @liboxbitnet_ffi@ separately and ensure the linker can find it. + . + See for the full project. + +homepage: https://github.com/m96-chan/0xBitNet +bug-reports: https://github.com/m96-chan/0xBitNet/issues +license: MIT +license-file: LICENSE +author: Yusuke Harada +maintainer: m96.chan.mfmf@gmail.com +category: AI, FFI +extra-source-files: README.md + +library + exposed-modules: + OxBitNet + Foreign.OxBitNet.Raw + build-depends: + base >= 4.14 && < 5 + hs-source-dirs: src + default-language: Haskell2010 + extra-libraries: oxbitnet_ffi + include-dirs: ../../oxbitnet-ffi + includes: oxbitnet.h + ghc-options: -Wall + +executable oxbitnet-chat + main-is: Chat.hs + build-depends: + base >= 4.14 && < 5, + oxbitnet + hs-source-dirs: examples + default-language: Haskell2010 + ghc-options: -Wall -threaded diff --git a/packages/rust/crates/oxbitnet-haskell/src/Foreign/OxBitNet/Raw.hsc b/packages/rust/crates/oxbitnet-haskell/src/Foreign/OxBitNet/Raw.hsc new file mode 100644 index 0000000..ef4f42c --- /dev/null +++ b/packages/rust/crates/oxbitnet-haskell/src/Foreign/OxBitNet/Raw.hsc @@ -0,0 +1,221 @@ +{-# LANGUAGE CApiFFI #-} +{-# LANGUAGE GeneralizedNewtypeDeriving #-} +{-# LANGUAGE PatternSynonyms #-} + +-- | Low-level FFI bindings to @liboxbitnet_ffi@. +module Foreign.OxBitNet.Raw + ( -- * Opaque handle + OxBitNet + + -- * Enums + , LoadPhase + , pattern PhaseDownload + , pattern PhaseParse + , pattern PhaseUpload + , LogLevel + , pattern LogTrace + , pattern LogDebug + , pattern LogInfo + , pattern LogWarn + , pattern LogError + + -- * Structs + , CGenerateOptions(..) + , CLoadProgress(..) + , CLoadOptions(..) + , CChatMessage(..) + + -- * Callback types + , TokenFn + , ProgressFn + , LogFn + , mkTokenFn + , mkProgressFn + , mkLogFn + + -- * FFI functions + , oxbitnet_set_logger + , oxbitnet_default_generate_options + , oxbitnet_default_load_options + , oxbitnet_load + , oxbitnet_free + , oxbitnet_generate + , oxbitnet_chat + , oxbitnet_error_message + ) where + +import Foreign +import Foreign.C + +#include "oxbitnet.h" + +-- -------------------------------------------------------------------- +-- Opaque handle +-- -------------------------------------------------------------------- + +-- | Opaque handle to a loaded BitNet model (never constructed in Haskell). +data OxBitNet + +-- -------------------------------------------------------------------- +-- Enums +-- -------------------------------------------------------------------- + +-- | Load progress phase. +newtype LoadPhase = LoadPhase CInt + deriving (Eq, Ord, Show, Storable) + +pattern PhaseDownload, PhaseParse, PhaseUpload :: LoadPhase +pattern PhaseDownload = LoadPhase #{const Download} +pattern PhaseParse = LoadPhase #{const Parse} +pattern PhaseUpload = LoadPhase #{const Upload} + +-- | Log level. +newtype LogLevel = LogLevel Word8 + deriving (Eq, Ord, Show, Storable) + +pattern LogTrace, LogDebug, LogInfo, LogWarn, LogError :: LogLevel +pattern LogTrace = LogLevel #{const Trace} +pattern LogDebug = LogLevel #{const Debug} +pattern LogInfo = LogLevel #{const Info} +pattern LogWarn = LogLevel #{const Warn} +pattern LogError = LogLevel #{const Error} + +-- -------------------------------------------------------------------- +-- Structs +-- -------------------------------------------------------------------- + +-- | @OxBitNetGenerateOptions@ +data CGenerateOptions = CGenerateOptions + { cgoMaxTokens :: !#{type uintptr_t} + , cgoTemperature :: !CFloat + , cgoTopK :: !#{type uintptr_t} + , cgoRepeatPenalty :: !CFloat + , cgoRepeatLastN :: !#{type uintptr_t} + } deriving (Show) + +instance Storable CGenerateOptions where + sizeOf _ = #{size OxBitNetGenerateOptions} + alignment _ = #{alignment OxBitNetGenerateOptions} + peek p = CGenerateOptions + <$> #{peek OxBitNetGenerateOptions, max_tokens} p + <*> #{peek OxBitNetGenerateOptions, temperature} p + <*> #{peek OxBitNetGenerateOptions, top_k} p + <*> #{peek OxBitNetGenerateOptions, repeat_penalty} p + <*> #{peek OxBitNetGenerateOptions, repeat_last_n} p + poke p v = do + #{poke OxBitNetGenerateOptions, max_tokens} p (cgoMaxTokens v) + #{poke OxBitNetGenerateOptions, temperature} p (cgoTemperature v) + #{poke OxBitNetGenerateOptions, top_k} p (cgoTopK v) + #{poke OxBitNetGenerateOptions, repeat_penalty} p (cgoRepeatPenalty v) + #{poke OxBitNetGenerateOptions, repeat_last_n} p (cgoRepeatLastN v) + +-- | @OxBitNetLoadProgress@ +data CLoadProgress = CLoadProgress + { clpPhase :: !LoadPhase + , clpLoaded :: !Word64 + , clpTotal :: !Word64 + , clpFraction :: !CDouble + } deriving (Show) + +instance Storable CLoadProgress where + sizeOf _ = #{size OxBitNetLoadProgress} + alignment _ = #{alignment OxBitNetLoadProgress} + peek p = CLoadProgress + <$> #{peek OxBitNetLoadProgress, phase} p + <*> #{peek OxBitNetLoadProgress, loaded} p + <*> #{peek OxBitNetLoadProgress, total} p + <*> #{peek OxBitNetLoadProgress, fraction} p + poke p v = do + #{poke OxBitNetLoadProgress, phase} p (clpPhase v) + #{poke OxBitNetLoadProgress, loaded} p (clpLoaded v) + #{poke OxBitNetLoadProgress, total} p (clpTotal v) + #{poke OxBitNetLoadProgress, fraction} p (clpFraction v) + +-- | @OxBitNetLoadOptions@ +data CLoadOptions = CLoadOptions + { cloOnProgress :: !(FunPtr ProgressFn) + , cloProgressUserdata :: !(Ptr ()) + , cloCacheDir :: !(Ptr CChar) + } deriving (Show) + +instance Storable CLoadOptions where + sizeOf _ = #{size OxBitNetLoadOptions} + alignment _ = #{alignment OxBitNetLoadOptions} + peek p = CLoadOptions + <$> #{peek OxBitNetLoadOptions, on_progress} p + <*> #{peek OxBitNetLoadOptions, progress_userdata} p + <*> #{peek OxBitNetLoadOptions, cache_dir} p + poke p v = do + #{poke OxBitNetLoadOptions, on_progress} p (cloOnProgress v) + #{poke OxBitNetLoadOptions, progress_userdata} p (cloProgressUserdata v) + #{poke OxBitNetLoadOptions, cache_dir} p (cloCacheDir v) + +-- | @OxBitNetChatMessage@ +data CChatMessage = CChatMessage + { ccmRole :: !(Ptr CChar) + , ccmContent :: !(Ptr CChar) + } deriving (Show) + +instance Storable CChatMessage where + sizeOf _ = #{size OxBitNetChatMessage} + alignment _ = #{alignment OxBitNetChatMessage} + peek p = CChatMessage + <$> #{peek OxBitNetChatMessage, role} p + <*> #{peek OxBitNetChatMessage, content} p + poke p v = do + #{poke OxBitNetChatMessage, role} p (ccmRole v) + #{poke OxBitNetChatMessage, content} p (ccmContent v) + +-- -------------------------------------------------------------------- +-- Callback types +-- -------------------------------------------------------------------- + +-- | Token callback: @int32_t (*)(const char *token, uintptr_t len, void *userdata)@ +type TokenFn = Ptr CChar -> #{type uintptr_t} -> Ptr () -> IO Int32 + +-- | Progress callback: @void (*)(const OxBitNetLoadProgress *progress, void *userdata)@ +type ProgressFn = Ptr CLoadProgress -> Ptr () -> IO () + +-- | Log callback: @void (*)(OxBitNetLogLevel level, const char *message, uintptr_t len, void *userdata)@ +type LogFn = Word8 -> Ptr CChar -> #{type uintptr_t} -> Ptr () -> IO () + +foreign import ccall "wrapper" + mkTokenFn :: TokenFn -> IO (FunPtr TokenFn) + +foreign import ccall "wrapper" + mkProgressFn :: ProgressFn -> IO (FunPtr ProgressFn) + +foreign import ccall "wrapper" + mkLogFn :: LogFn -> IO (FunPtr LogFn) + +-- -------------------------------------------------------------------- +-- FFI imports +-- -------------------------------------------------------------------- + +-- safe: long-running / callbacks +foreign import ccall safe "oxbitnet_load" + oxbitnet_load :: Ptr CChar -> Ptr CLoadOptions -> IO (Ptr OxBitNet) + +foreign import ccall safe "oxbitnet_generate" + oxbitnet_generate :: Ptr OxBitNet -> Ptr CChar -> Ptr CGenerateOptions + -> FunPtr TokenFn -> Ptr () -> IO Int32 + +foreign import ccall safe "oxbitnet_chat" + oxbitnet_chat :: Ptr OxBitNet -> Ptr CChatMessage -> #{type uintptr_t} + -> Ptr CGenerateOptions -> FunPtr TokenFn -> Ptr () -> IO Int32 + +foreign import ccall safe "oxbitnet_free" + oxbitnet_free :: Ptr OxBitNet -> IO () + +-- unsafe: quick / no callbacks +foreign import ccall unsafe "oxbitnet_default_generate_options" + oxbitnet_default_generate_options :: IO CGenerateOptions + +foreign import ccall unsafe "oxbitnet_default_load_options" + oxbitnet_default_load_options :: IO CLoadOptions + +foreign import ccall unsafe "oxbitnet_error_message" + oxbitnet_error_message :: IO (Ptr CChar) + +foreign import ccall safe "oxbitnet_set_logger" + oxbitnet_set_logger :: FunPtr LogFn -> Ptr () -> Word8 -> IO () diff --git a/packages/rust/crates/oxbitnet-haskell/src/OxBitNet.hs b/packages/rust/crates/oxbitnet-haskell/src/OxBitNet.hs new file mode 100644 index 0000000..fef31c3 --- /dev/null +++ b/packages/rust/crates/oxbitnet-haskell/src/OxBitNet.hs @@ -0,0 +1,321 @@ +{-# LANGUAGE DeriveAnyClass #-} +{-# LANGUAGE DeriveGeneric #-} + +-- | High-level, managed Haskell API for oxbitnet. +-- +-- Typical usage: +-- +-- @ +-- 'withBitNet' "model.gguf" 'defaultLoadOptions' $ \\model -> do +-- 'generate' model "Hello!" 'defaultGenerateOptions' $ \\token -> do +-- putStr token +-- return False -- False = continue +-- @ +module OxBitNet + ( -- * Model handle + BitNet + , withBitNet + , loadBitNet + , freeBitNet + + -- * Generation + , generate + , chat + + -- * Logger + , setLogger + + -- * Options + , GenerateOptions(..) + , defaultGenerateOptions + , LoadOptions(..) + , defaultLoadOptions + + -- * Types + , ChatMessage(..) + , userMessage + , assistantMessage + , systemMessage + , LoadProgress(..) + , LoadPhase(..) + , LogLevel(..) + + -- * Exceptions + , OxBitNetException(..) + ) where + +import Control.Concurrent.MVar +import Control.Exception (Exception, bracket, throwIO) +import Control.Monad (when) +import Data.IORef +import Data.Word (Word8, Word64) +import Foreign hiding (with) +import Foreign.C +import GHC.Generics (Generic) + +import qualified Foreign.OxBitNet.Raw as Raw + +-- -------------------------------------------------------------------- +-- Types +-- -------------------------------------------------------------------- + +-- | Thread-safe handle to a loaded BitNet model. +newtype BitNet = BitNet (MVar (Maybe (Ptr Raw.OxBitNet))) + +-- | Generation options. +data GenerateOptions = GenerateOptions + { maxTokens :: !Int -- ^ Maximum number of tokens to generate (default: 256). + , temperature :: !Float -- ^ Sampling temperature (default: 1.0). + , topK :: !Int -- ^ Top-k sampling parameter (default: 50). + , repeatPenalty :: !Float -- ^ Repetition penalty (default: 1.1). + , repeatLastN :: !Int -- ^ Window size for repetition penalty (default: 64). + } deriving (Show, Eq) + +-- | Sensible defaults for generation. +defaultGenerateOptions :: GenerateOptions +defaultGenerateOptions = GenerateOptions + { maxTokens = 256 + , temperature = 1.0 + , topK = 50 + , repeatPenalty = 1.1 + , repeatLastN = 64 + } + +-- | Options for loading a model. +data LoadOptions = LoadOptions + { onProgress :: Maybe (LoadProgress -> IO ()) + -- ^ Progress callback (optional). + , cacheDir :: Maybe String + -- ^ Cache directory path (optional). + } + +-- | Default load options (no progress callback, no cache dir). +defaultLoadOptions :: LoadOptions +defaultLoadOptions = LoadOptions Nothing Nothing + +-- | Load progress information. +data LoadProgress = LoadProgress + { lpPhase :: !LoadPhase + , lpLoaded :: !Word64 + , lpTotal :: !Word64 + , lpFraction :: !Double + } deriving (Show) + +-- | Load progress phase. +data LoadPhase + = Download + | Parse + | Upload + deriving (Show, Eq, Ord, Enum, Bounded) + +-- | Log severity level. +data LogLevel + = Trace + | Debug + | Info + | Warn + | Error + deriving (Show, Eq, Ord, Enum, Bounded) + +-- | A chat message with role and content. +data ChatMessage = ChatMessage + { cmRole :: !String + , cmContent :: !String + } deriving (Show, Eq) + +-- | Construct a user message. +userMessage :: String -> ChatMessage +userMessage = ChatMessage "user" + +-- | Construct an assistant message. +assistantMessage :: String -> ChatMessage +assistantMessage = ChatMessage "assistant" + +-- | Construct a system message. +systemMessage :: String -> ChatMessage +systemMessage = ChatMessage "system" + +-- | Exceptions thrown by oxbitnet operations. +data OxBitNetException + = LoadError String + | GenerateError String + | Disposed + deriving (Show, Eq, Generic, Exception) + +-- -------------------------------------------------------------------- +-- Model lifecycle +-- -------------------------------------------------------------------- + +-- | Load a model, run an action, then free it. Ensures cleanup on exception. +withBitNet :: String -> LoadOptions -> (BitNet -> IO a) -> IO a +withBitNet source opts = bracket (loadBitNet source opts) freeBitNet + +-- | Load a model from a URL or local file path. +loadBitNet :: String -> LoadOptions -> IO BitNet +loadBitNet source opts = + withCString source $ \cSource -> do + cOpts <- Raw.oxbitnet_default_load_options + + -- Set up progress callback + progressFunPtr <- case onProgress opts of + Nothing -> return nullFunPtr + Just cb -> Raw.mkProgressFn $ \pProgress _userdata -> do + raw <- peek pProgress + let progress = LoadProgress + { lpPhase = toLoadPhase (Raw.clpPhase raw) + , lpLoaded = Raw.clpLoaded raw + , lpTotal = Raw.clpTotal raw + , lpFraction = realToFrac (Raw.clpFraction raw) + } + cb progress + + let cOpts' = cOpts + { Raw.cloOnProgress = progressFunPtr + } + + -- Set up cache dir + result <- case cacheDir opts of + Nothing -> alloca $ \pOpts -> do + poke pOpts cOpts' + Raw.oxbitnet_load cSource pOpts + Just dir -> withCString dir $ \cDir -> alloca $ \pOpts -> do + poke pOpts cOpts' { Raw.cloCacheDir = cDir } + Raw.oxbitnet_load cSource pOpts + + -- Free progress FunPtr (no longer needed after load completes) + when (progressFunPtr /= nullFunPtr) $ + freeHaskellFunPtr progressFunPtr + + if result == nullPtr + then do + errPtr <- Raw.oxbitnet_error_message + msg <- if errPtr == nullPtr + then return "Failed to load model" + else peekCString errPtr + throwIO (LoadError msg) + else BitNet <$> newMVar (Just result) + +-- | Free a model handle. Safe to call multiple times. +freeBitNet :: BitNet -> IO () +freeBitNet (BitNet mvar) = modifyMVar_ mvar $ \mHandle -> do + case mHandle of + Nothing -> return Nothing + Just ptr -> do + Raw.oxbitnet_free ptr + return Nothing + +-- | Acquire the raw pointer, throwing 'Disposed' if already freed. +withHandle :: BitNet -> (Ptr Raw.OxBitNet -> IO a) -> IO a +withHandle (BitNet mvar) action = withMVar mvar $ \mHandle -> + case mHandle of + Nothing -> throwIO Disposed + Just ptr -> action ptr + +-- -------------------------------------------------------------------- +-- Generation +-- -------------------------------------------------------------------- + +-- | Generate text from a raw prompt. +-- +-- The callback receives each token as a 'String'. Return 'True' to stop +-- generation early, 'False' to continue. +generate :: BitNet -> String -> GenerateOptions -> (String -> IO Bool) -> IO () +generate model prompt opts onToken = + withHandle model $ \handle -> + withCString prompt $ \cPrompt -> + withCGenerateOptions opts $ \pOpts -> do + tokenFn <- Raw.mkTokenFn $ \cStr len _userdata -> do + str <- peekCStringLen (cStr, fromIntegral len) + stop <- onToken str + return (if stop then 1 else 0) + ret <- Raw.oxbitnet_generate handle cPrompt pOpts tokenFn nullPtr + freeHaskellFunPtr tokenFn + when (ret /= 0) $ do + errPtr <- Raw.oxbitnet_error_message + msg <- if errPtr == nullPtr + then return "Generate failed" + else peekCString errPtr + throwIO (GenerateError msg) + +-- | Generate text from chat messages. +-- +-- The callback receives each token as a 'String'. Return 'True' to stop +-- generation early, 'False' to continue. +chat :: BitNet -> [ChatMessage] -> GenerateOptions -> (String -> IO Bool) -> IO () +chat model messages opts onToken = + withHandle model $ \handle -> + withCChatMessages messages $ \pMsgs numMsgs -> + withCGenerateOptions opts $ \pOpts -> do + tokenFn <- Raw.mkTokenFn $ \cStr len _userdata -> do + str <- peekCStringLen (cStr, fromIntegral len) + stop <- onToken str + return (if stop then 1 else 0) + ret <- Raw.oxbitnet_chat handle pMsgs (fromIntegral numMsgs) pOpts tokenFn nullPtr + freeHaskellFunPtr tokenFn + when (ret /= 0) $ do + errPtr <- Raw.oxbitnet_error_message + msg <- if errPtr == nullPtr + then return "Chat failed" + else peekCString errPtr + throwIO (GenerateError msg) + +-- -------------------------------------------------------------------- +-- Logger +-- -------------------------------------------------------------------- + +-- | Install a global logger. Must be called before loading any model. +-- Can only be called once; subsequent calls are no-ops. +-- +-- The 'FunPtr' is intentionally leaked (process-lifetime, matching the C API +-- contract that the logger is installed once and lives forever). +setLogger :: LogLevel -> (LogLevel -> String -> IO ()) -> IO () +setLogger minLevel cb = do + funPtr <- Raw.mkLogFn $ \level cStr len _userdata -> do + str <- peekCStringLen (cStr, fromIntegral len) + cb (toLogLevel level) str + Raw.oxbitnet_set_logger funPtr nullPtr (fromIntegral (fromEnum minLevel)) + +-- -------------------------------------------------------------------- +-- Internal helpers +-- -------------------------------------------------------------------- + +toLoadPhase :: Raw.LoadPhase -> LoadPhase +toLoadPhase p + | p == Raw.PhaseDownload = Download + | p == Raw.PhaseParse = Parse + | p == Raw.PhaseUpload = Upload + | otherwise = Download -- fallback + +toLogLevel :: Word8 -> LogLevel +toLogLevel n + | n <= 0 = Trace + | n == 1 = Debug + | n == 2 = Info + | n == 3 = Warn + | otherwise = Error + +withCGenerateOptions :: GenerateOptions -> (Ptr Raw.CGenerateOptions -> IO a) -> IO a +withCGenerateOptions opts action = alloca $ \p -> do + poke p Raw.CGenerateOptions + { Raw.cgoMaxTokens = fromIntegral (maxTokens opts) + , Raw.cgoTemperature = realToFrac (temperature opts) + , Raw.cgoTopK = fromIntegral (topK opts) + , Raw.cgoRepeatPenalty = realToFrac (repeatPenalty opts) + , Raw.cgoRepeatLastN = fromIntegral (repeatLastN opts) + } + action p + +-- | Marshal a list of 'ChatMessage' into a C array, using nested +-- 'withCString' to keep all strings alive for the duration of the action. +withCChatMessages :: [ChatMessage] -> (Ptr Raw.CChatMessage -> Int -> IO a) -> IO a +withCChatMessages msgs action = go msgs [] where + go [] acc = do + let n = length acc + cMsgs = reverse acc + allocaArray n $ \pArr -> do + pokeArray pArr cMsgs + action pArr n + go (ChatMessage role content : rest) acc = + withCString role $ \cRole -> + withCString content $ \cContent -> + go rest (Raw.CChatMessage cRole cContent : acc)