diff --git a/.env.example b/.env.example deleted file mode 100644 index 1179511..0000000 --- a/.env.example +++ /dev/null @@ -1,12 +0,0 @@ -# ── LLM Validation ──────────────────────────────────────────────────────────── -# Validates extracted conversation samples for coherence and quality before -# writing the dataset. Enabled by default when ANTHROPIC_API_KEY is set. -# Set to false to skip validation entirely (faster, no API calls). -DIALOGSMITH_LLM_VALIDATE=true - -# Model used for validation scoring (defaults to claude-haiku-4-5-20251001). -# A fast, cheap model is recommended here — the validator runs once per sample. -DIALOGSMITH_LLM_MODEL=claude-haiku-4-5-20251001 - -# Your Anthropic API key. Required when DIALOGSMITH_LLM_VALIDATE=true. -ANTHROPIC_API_KEY=your_api_key_here diff --git a/.gitignore b/.gitignore index 542add2..f93c2fe 100644 --- a/.gitignore +++ b/.gitignore @@ -10,6 +10,14 @@ ChatExport*/ # Tracked project config (re-include despite the broad *.json rule above) !configs/dataset_info.json +# Synthetic demo input — safe to commit and needed for the reproducible demo. +# (Generated demo outputs like demo/sample_sharegpt.json stay ignored via *.json.) +!demo/sample_export.json + +# Personal training/export overrides — copy a tracked config to *.local.yaml and +# edit that for your own model/hardware; it stays out of git. +configs/*.local.yaml + # Python cache / bytecode __pycache__/ *.py[cod] diff --git a/README.md b/README.md index 292694b..17bbc15 100644 --- a/README.md +++ b/README.md @@ -1,135 +1,249 @@ -# Doppelganger – Fine-Tune Models on Your Chat History +

Doppelganger

-**Doppelganger** lets you fine-tune large language models (LLMs) like Qwen on your own chat -conversations. Built on top of [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory), it -formats your data into the ShareGPT format for supervised fine-tuning (SFT). +

+ Fine-tune an LLM on your own chat history to mimic how you write +

-Ingestion is **source-agnostic**: a small adapter parses each platform's export into a normalized -message stream, and the rest of the pipeline (sessionizing, turn-merging, optional quality -validation, ShareGPT formatting) is shared. **Telegram** is supported today; other sources -(WhatsApp, etc.) are planned and slot in as drop-in adapters — see [issue #9](https://github.com/NotYuSheng/Doppelganger/issues/9). +

+ Features • + Quick Start • + Usage • + Fine-Tuning • + Privacy +

-## Purpose +

+ Python + PyTorch + LLaMA-Factory + OpenAI-compatible + License: MIT +

-Fine-tuning on chat data can capture aspects of your text style, including: +

+ Doppelganger CLI: ingest a chat export and scan it for sensitive data before training +

-* Writing tone, vocabulary, and phrasing -* Typical response lengths and structure -* Repeated expressions or idioms -* Conversational flow and habits +--- -However, this method **won’t replicate your deeper beliefs, private memories, or behavior outside the chat**. It reflects how you write — not necessarily how you think. +Doppelganger fine-tunes large language models (like Qwen) on your own chat conversations, capturing how *you* write. Built on top of [LLaMA-Factory](https://github.com/hiyouga/LLaMA-Factory), it turns a raw chat export into a [ShareGPT](https://github.com/hiyouga/LLaMA-Factory/blob/main/data/README.md)-formatted dataset for supervised fine-tuning (SFT), then trains a LoRA adapter on it. -For stronger emulation, consider incorporating: +Ingestion is **source-agnostic**: a small adapter parses each platform's export into a normalized message stream, and the rest of the pipeline (sessionizing, turn-merging, sensitive-data scanning, optional quality auditing, ShareGPT formatting) is shared. **Telegram** is supported today, with **WhatsApp**, **Discord**, and other platforms planned — each slots in as a drop-in adapter. -* Additional sources like emails or forum posts -* Clear prompt instructions during inference -* Domain-specific datasets (e.g., technical messages, inside jokes) +> [!CAUTION] +> **Your chat history is sensitive data, and you are responsible for it.** A model fine-tuned on it can memorize and later reproduce personal identifiers, private conversations, credentials, and messages written by other people in your chats. The built-in [sensitive-data scanning](#privacy--sensitive-data) is a **safety net, not a guarantee** — both regex and LLM detection miss real cases and raise false positives. Before training, sharing, or deploying anything: **review the dataset yourself**, get consent from others whose messages are included (especially in group chats), and comply with applicable privacy laws. 
Treat trained adapters and merged checkpoints as sensitive too — they can leak the data they were trained on. -## Warning: Risk of Sensitive Data Exposure +> [!IMPORTANT] +> **This is a for-fun, experimental project — not a production tool.** A model that imitates a real person can be misused for impersonation, deception, or social engineering, and it will happily generate convincing messages that person never actually wrote. Don't present its output as genuinely from anyone, and don't rely on it for anything that matters. Enjoy it responsibly. -Fine-tuning on real chat history may unintentionally encode: +Fine-tuning on your chats can capture your: -* Personal identifiers (names, locations, contact info) -* Confidential conversations -* Sensitive or offensive content +- **Writing tone, vocabulary, and phrasing** +- **Typical response lengths and structure** +- **Repeated expressions and idioms** +- **Conversational flow and habits** -> **Always review and sanitize your exported dataset (`result.json`) before training.** -> You are responsible for ensuring compliance with privacy laws and personal data protection. +> **Note**: This reflects *how you write*, not how you think — it **won't** replicate your deeper beliefs, private memories, or behaviour outside the chat. For stronger emulation, add other sources (emails, forum posts), clear prompt instructions at inference, and domain-specific data (technical messages, inside jokes). -### Keeping your data out of git +## Features -Your chat export and any generated datasets are ignored by `.gitignore` -(`result.json`, `*.json`, `*.jsonl`, `DataExport*/`, `*.session`, `.env`, plus Telegram -media/contacts such as `*.vcard`, `*.tgs`, `*.webp`, `*.ogg`/`*.oga`). Generic -media (`.jpg`, `.mp4`, …) lives inside `DataExport*/`, which is ignored -wholesale. As an extra safeguard, a pre-commit hook refuses to commit these -files even if they are force-added. 
Enable it once per clone: +| Feature | Description | +|---------|-------------| +| **Source-agnostic ingestion** | One adapter per platform parses an export into a normalized message stream; the rest of the pipeline is shared. Telegram today; others drop in without touching the core. | +| **Conversation reconstruction** | Sessionizes messages by silence gaps **and** reply links, merges consecutive turns, and (optionally) preserves per-speaker labels in group chats. | +| **Sensitive-data scan** | Non-destructive regex scan over the built conversations — email, payment cards (checksum-validated), IP/MAC, API keys, plus pluggable country ID packs. Writes an audit report; you decide what to remove. | +| **LLM redaction** *(optional)* | An OpenAI-compatible model flags context-dependent PII (names, secrets) regex misses, into the same report and apply step. Local-first by design. | +| **LLM quality auditor** *(optional)* | Scores each conversation for coherence, quality, and pairing; drops weak samples and splits over-merged ones. | +| **ShareGPT output** | Emits exactly the format LLaMA-Factory consumes for SFT, with loss masked to your own turns. | +| **LoRA fine-tuning** | Ready-made train / export / chat configs; swap the base model in one place. | -```bash -git config core.hooksPath hooks -``` +## Quick Start -To deliberately commit a blocked file, bypass the hook with `git commit --no-verify`. +### Prerequisites + +| Software | Version | Purpose | +|----------|---------|---------| +| Python | 3.11–3.13 | Required by LLaMA-Factory 0.9.4 | +| PyTorch | CUDA build for your GPU | Training (see the [install matrix](https://pytorch.org/get-started/locally/)) | +| git | Latest | Clone + the dataset-hygiene pre-commit hook | +| LLM server | Any OpenAI-compatible API | **Optional** — quality auditing & LLM redaction (Ollama, vLLM, LM Studio, OpenAI) | -## Requirements +A CUDA-capable GPU is needed for training. Ingestion (parsing → dataset) runs fine on CPU. 
-* **Python 3.11–3.13** (required by LLaMA-Factory 0.9.4) -* A CUDA-capable GPU for training, with a matching [PyTorch build](https://pytorch.org/get-started/locally/) -* `git` +### Installation -## Export Your Telegram Chat +**1. Export your Telegram chat** -1. Open **Telegram Desktop**. -2. Go to: `Settings > Advanced > Export Telegram Data`. -3. Select your personal chat or group to export. -4. Ensure **JSON** format is selected (not HTML). -5. Place the exported `result.json` file into: +In **Telegram Desktop**: `Settings > Advanced > Export Telegram Data`. Select your chat(s), choose **JSON** format (not HTML), and place the result here: ``` Doppelganger/ -├── data/ -│ └── result.json ← Place here +└── data/ + └── result.json ← place your export here ``` -## Setup +**2. Clone and run setup** -The setup scripts create a virtual environment, install pinned dependencies -(LLaMA-Factory **0.9.4**), and process your export into `data/chat_sharegpt.json`. - -**Linux / macOS:** +The setup scripts create a virtual environment, install pinned dependencies (LLaMA-Factory **0.9.4**), create your `.env`, and process the export into `data/chat_sharegpt.json`. ```bash -./setup.sh -``` - -**Windows** (from **Command Prompt**, not PowerShell): +git clone https://github.com/NotYuSheng/Doppelganger.git +cd Doppelganger -```cmd -setup.bat +./setup.sh # Linux / macOS +setup.bat # Windows (from Command Prompt, not PowerShell) ``` -Prefer to do it manually? The scripts are thin wrappers around: +
+Prefer to run it manually? ```bash python -m venv venv -# activate: source venv/bin/activate (Windows: venv\Scripts\activate) +source venv/bin/activate # Windows: venv\Scripts\activate pip install -r requirements.txt python -m ingest --source telegram ``` +
-### Ingestion options +**3. (Optional) Configure LLM features** -`python -m ingest` turns a raw export into a dataset. Useful flags: +The core pipeline needs no LLM. To *also* enable the quality auditor and LLM redaction, copy `example.env` to `.env` (the setup scripts do this) and point it at a **local** OpenAI-compatible server (vLLM, LM Studio, llama.cpp) so your chat data stays on your machine: -| Flag | Default | Description | -| --------------------- | ------------------------- | ------------------------------------------------------ | -| `--source` | `telegram` | Chat source to parse (more planned) | -| `--input` | `./data/result.json` | Path to the raw export | -| `--format` | `sharegpt` | `sharegpt` (for training) or `jsonl` (intermediate) | -| `--self-name` | auto-detected | Override which sender is "you" | -| `--conversation-gap` | `3600` | Seconds of silence that start a new conversation | -| `--message-chain` | `30` | Max seconds between same-sender messages to merge | +```dotenv +LLM_VALIDATE=true +LLM_API_BASE_URL=http://localhost:8000/v1 # vLLM (LM Studio uses :1234/v1) +LLM_MODEL=Qwen/Qwen2.5-7B-Instruct # the model your server serves +LLM_API_KEY=local # local servers accept any value +``` -### Optional: LLM quality validation +**4. Fine-tune** -Each extracted conversation can be scored for coherence and quality, dropping weak samples before -training. It is enabled automatically when `ANTHROPIC_API_KEY` is set. Copy `.env.example` to `.env` -and fill it in (the setup scripts do this for you): +```bash +source venv/bin/activate +llamafactory-cli train configs/train_lora.yaml +``` -```dotenv -DIALOGSMITH_LLM_VALIDATE=true -DIALOGSMITH_LLM_MODEL=claude-haiku-4-5-20251001 -ANTHROPIC_API_KEY=your_api_key_here +## Usage + +`python -m ingest` turns a raw export into a training-ready dataset. 
Useful flags: + +| Flag | Default | Description | +|------|---------|-------------| +| `--source` | `telegram` | Chat source to parse (more planned) | +| `--input` | `./data/result.json` | Path to the raw export | +| `--format` | `sharegpt` | `sharegpt` (for training) or `jsonl` (intermediate) | +| `--self-name` | auto-detected | Override which sender is "you" | +| `--conversation-gap` | `3600` | Seconds of silence that start a new conversation | +| `--message-chain` | `30` | Max seconds between same-sender messages to merge into one turn | +| `--multi-speaker` | off | In group chats, keep and label each sender on the human side (your turns are never labelled) | +| `--no-audit` | off | Master off-switch: skip **all** auditing (regex scan + LLM validation) and just build the dataset | +| `--skip-redact-scan` | off | Skip only the regex sensitive-data scan | +| `--skip-validation` | off | Skip only the LLM quality validation | + +### Optional: LLM quality auditing + +Each extracted conversation can be scored for **coherence, quality, and pairing**, dropping or splitting weak samples before training. It talks to a **local** OpenAI-compatible server (vLLM, LM Studio, llama.cpp) so your chat data stays on your machine. It's enabled automatically when `LLM_API_KEY` or `LLM_API_BASE_URL` is set (configure it in `.env`, step 3 above). + +To turn it off, set `LLM_VALIDATE=false` in `.env` (persistent) or pass `--skip-validation` for a single run. To disable **all** auditing at once — both this and the regex scan — use `--no-audit`. + +### Running a local LLM (recommended: LM Studio) + +The LLM features are designed to run against a **local** model so your chat data never leaves your machine. [LM Studio](https://lmstudio.ai) is the easiest way to get one running with a click-through UI: + +1. Install **LM Studio** and use its search to download a model (see the table below). +2. Open the **Developer** tab → **Start Server**. 
It serves an OpenAI-compatible API at `http://localhost:1234/v1`. +3. In `.env`, set: + ```dotenv + LLM_VALIDATE=true + LLM_API_BASE_URL=http://localhost:1234/v1 + LLM_MODEL= + LLM_API_KEY=local + ``` + +(Prefer the CLI? **vLLM** serves the same API at `http://localhost:8000/v1` with `--model Qwen/Qwen2.5-7B-Instruct`. **Ollama** also works at `http://localhost:11434/v1`.) + +**Which model?** The auditor/redactor just needs solid instruction-following and JSON output — a small model is plenty. Pick by your hardware (GGUF quants in LM Studio shrink the footprint): + +| Your hardware | Suggested model | Notes | +|---------------|-----------------|-------| +| CPU-only, or ≤8 GB VRAM / 16 GB RAM | **Qwen2.5-3B-Instruct** (Q4) | Fast and light; fine for scoring + PII spans | +| 8–16 GB VRAM | **Qwen2.5-7B-Instruct** (Q4/Q5) | Recommended balance of quality and speed | +| 24 GB+ VRAM | **Qwen2.5-14B-Instruct** | Best judgment on tricky/ambiguous cases | + +Tiny machine? **Qwen2.5-1.5B-Instruct** or **Llama-3.2-3B-Instruct** also work, with slightly noisier results. Any OpenAI-compatible model will do — these are just sensible starting points. + +## Privacy & Sensitive Data + +Fine-tuning on real chat history may unintentionally encode personal identifiers, confidential conversations, or sensitive content. + +> **Always review and sanitize your dataset before training.** You are responsible for compliance with privacy laws and personal data protection. + +### Automated sensitive-data scan + +To make that review practical, ingestion runs a **regex-based scan** over the built conversations by default. 
It is **non-destructive** — it only flags and warns, writing `data/redaction_report.json` (with masked previews) and printing a summary so you can decide what to do: + +``` +[redactor] WARNING: 4 potential sensitive item(s) detected across 2 conversations: + EMAIL 2 hit(s) in 2 conversation(s) [medium] + CARD_NUMBER 1 hit(s) in 1 conversation(s) [high] + API_KEY 1 hit(s) in 1 conversation(s) [high] +``` + +Detection works everywhere out of the box. **Universal detectors** — email, payment cards (checksum-validated), IP/MAC addresses, API keys and private keys — aren't tied to any country and always run. On top of those, optional **locale packs** add country-specific identifiers (national IDs, local phone/postal formats). + +Once you've reviewed the report, act on it: + +```bash +python -m ingest --source telegram --redact replace # swap spans for [CATEGORY] +python -m ingest --source telegram --redact drop # drop flagged conversations +python -m ingest --source telegram --skip-redact-scan # opt out entirely ``` -Set `DIALOGSMITH_LLM_VALIDATE=false` to skip validation entirely (no API calls). +### Add coverage for your country + +Locale packs are built to be community-contributed: each is a single drop-in module under [`ingest/redaction/`](ingest/redaction/), needing no changes to the scanner or pipeline. Adding one is three steps: + +1. Copy an existing pack to `ingest/redaction/.py` (your ISO country code). +2. Register detectors with `make(...)` and `locale=""`. Back each pattern with a checksum/validator where the identifier has one — that precision is what keeps the report trustworthy instead of noisy. +3. Import your module in `ingest/redaction/__init__.py`. + +Singapore ships as the worked reference ([`sg.py`](ingest/redaction/sg.py): national ID with checksum, local phone, postal code) — but the recipe is the same for any country, and **PRs for new locales are welcome**. 
Choose which packs run with `--redact-locales` (universal detectors always run regardless). + +### LLM-assisted redaction + +Regex can't catch everything (names, context-dependent secrets). With `--llm-redact`, an LLM additionally flags such spans into the **same report and the same `--redact` step** — it points at verbatim spans, never rewriting your text. To protect your data it **prefers a local endpoint**: set `LLM_API_BASE_URL` to a local OpenAI-compatible server; without one it refuses to use a hosted API unless you pass `--allow-cloud-redaction`. + +```bash +LLM_API_BASE_URL=http://localhost:8000/v1 LLM_MODEL=Qwen/Qwen2.5-7B-Instruct \ + python -m ingest --source telegram --llm-redact --redact replace +``` + +### Keeping your data out of git + +Your chat export and any generated datasets are ignored by `.gitignore` (`result.json`, `*.json`, `*.jsonl`, `DataExport*/`, `*.session`, `.env`, plus Telegram media/contacts such as `*.vcard`, `*.tgs`, `*.webp`, `*.ogg`/`*.oga`). Generic media (`.jpg`, `.mp4`, …) lives inside `DataExport*/`, which is ignored wholesale. As an extra safeguard, a pre-commit hook refuses to commit these files even if they are force-added. Enable it once per clone: + +```bash +git config core.hooksPath hooks +``` + +To deliberately commit a blocked file, bypass the hook with `git commit --no-verify`. + +## Intended Use & Responsible Use + +Doppelganger is a **personal, educational project** — built for individuals to experiment with fine-tuning on **their own** chat history, for fun and learning. It is not a product, and it is **not** intended for profiling or surveilling other people, or for any commercial or deceptive use. + +If you use it, please: + +- **Use your own data.** Train on chats you're a participant in. Group chats include other people's messages — be considerate, and don't publish models trained on them. 
+- **Keep it local.** Don't publish the dataset, the trained adapter, or merged checkpoints — they can leak the conversations they were trained on. +- **Don't impersonate or deceive.** Never present generated text as something a real person actually said or wrote. +- **Respect the law.** You are responsible for complying with the privacy and data-protection laws in your jurisdiction. + +In short: it's a toy for exploring how *you* write — please keep it that way. ## Fine-Tune Your Model (LoRA) -Training is configured by [`configs/train_lora.yaml`](configs/train_lora.yaml), which defaults to -**Qwen1.5-1.8B-Chat** and the `chat_sharegpt` dataset registered in -[`configs/dataset_info.json`](configs/dataset_info.json). Activate your venv, then run: +Training is configured by [`configs/train_lora.yaml`](configs/train_lora.yaml), which defaults to **Qwen1.5-1.8B-Chat** and the `chat_sharegpt` dataset registered in [`configs/dataset_info.json`](configs/dataset_info.json). Activate your venv, then run: ```bash llamafactory-cli train configs/train_lora.yaml @@ -139,16 +253,27 @@ llamafactory-cli train configs/train_lora.yaml Edit `configs/train_lora.yaml`: -| Field | Description | -| ---------------------- | -------------------------------------------------------- | -| `model_name_or_path` | Hugging Face model ID or local model path | -| `template` | Prompt template type (e.g., `qwen`, `chatml`, `default`) | -| `lora_target` | LoRA target modules (`all` works across architectures) | -| `output_dir` | Destination to save the LoRA checkpoints | +| Field | Description | +|-------|-------------| +| `model_name_or_path` | Hugging Face model ID or local model path | +| `template` | Prompt template type (e.g. 
`qwen`, `chatml`, `default`) | +| `lora_target` | LoRA target modules (`all` works across architectures) | +| `output_dir` | Destination to save the LoRA checkpoints | + +For example, to use `mistralai/Mistral-7B-Instruct-v0.2`, set `model_name_or_path` accordingly and `template: chatml`. Refer to the [LLaMA-Factory model table](https://github.com/hiyouga/LLaMA-Factory#supported-models) for recommended values. + +> **Note**: Training masks the loss to your own (assistant) turns — `train_on_prompt: false`. That's why `--multi-speaker` labels on the human side are safe: the model reads them as context but never learns to produce them. -For example, to use `mistralai/Mistral-7B-Instruct-v0.2`, set `model_name_or_path` accordingly and -`template: chatml`. Refer to the -[LLaMA-Factory model table](https://github.com/hiyouga/LLaMA-Factory#supported-models) for recommended values. +#### Keep personal tweaks out of git + +The configs above are committed defaults — editing them in place shows up in `git status` and risks committing your machine-specific model/hyperparameters. To customize **without touching tracked files**, copy a config to a `*.local.yaml` name and edit that instead. Any `configs/*.local.yaml` is gitignored: + +```bash +cp configs/train_lora.yaml configs/train_lora.local.yaml # edit model, batch size, etc. +llamafactory-cli train configs/train_lora.local.yaml +``` + +The same works for `export_lora.local.yaml`. Your overrides stay local; the repo's defaults stay clean. ### Resume training @@ -176,7 +301,7 @@ llamafactory-cli chat \ Update `--template` to match the one used during training. 
-## Activating the environment later +## Activating the Environment Later After running setup once, reactivate the venv in future sessions before running any commands: @@ -185,21 +310,52 @@ source venv/bin/activate # Linux / macOS venv\Scripts\activate # Windows (Command Prompt) ``` -## Running the tests +## Running the Tests -The ingestion pipeline (parsing, sessionizing, turn-merging, ShareGPT formatting) is covered by a -fast unit suite — no GPU, network, or API key required: +The ingestion pipeline (parsing, sessionizing, turn-merging, reply-threading, sensitive-data detection, ShareGPT formatting) is covered by a fast unit suite — no GPU, network, or API key required: ```bash python -m unittest discover -s tests -t . ``` -It runs in well under a second and locks in the conversion behaviour, so you can verify a change -without running the full pipeline. +It runs in well under a second and locks in the conversion behaviour, so you can verify a change without running the full pipeline. + +## Legacy Workflow + +The pre-refactor, Windows-only workflow (which cloned LLaMA-Factory at HEAD) is preserved at the [`v0.1.0`](https://github.com/NotYuSheng/Doppelganger/releases/tag/v0.1.0) tag. The old `scripts/telegram_extract.py` and `scripts/convert_to_sharegpt.py` shims have been removed — use `python -m ingest` instead. + +## Roadmap & Vision + +Doppelganger is as much a **learning sandbox** as a tool: the aim is to explore the *full* AI toolbox for capturing how a person communicates, and to find what actually moves the needle on *"does this sound like me?"*. Today that's LoRA fine-tuning — everything below is exploratory (see the issue tracker for the live backlog). 
+ +**Shaping the model** — pre-training · fine-tuning · alignment/DPO · continual learning · synthetic data / self-instruct · multi-LoRA personas & merging · distillation to on-device · PEFT comparison + +**Giving it context & memory** — RAG · long-term memory + reflection · relationship/knowledge graph · style embeddings / user-conditioning · persona-prompt quiz · MCP + +**Multimodal** — voice cloning, TTS/STT · stickers / emoji / memes + +**Making it act** — agentic doppelganger · multi-agent & self-play · proactive / initiative modeling + +**Inference-time control** — activation steering / control vectors · prompt optimization (DSPy) + +**Keeping it safe & honest** — guardrails · redaction (shipped) + offline NER · differential-privacy training · machine unlearning · memorization audits / canaries / watermarking / federated + +**Knowing if it works** — evaluation, "does it sound like me?" · interpretability, "what did it learn about me?" + +**More data & coverage** — more chat sources: WhatsApp, Discord, … · wider locale detector packs + +> This is an experimental, for-fun project — the roadmap is a wishlist of things to explore, not a commitment. + +## Star History + + + + + + Star History Chart + + -## Legacy workflow +## License -The pre-refactor, Windows-only workflow (which cloned LLaMA-Factory at HEAD) is preserved at the -[`v0.1.0`](https://github.com/NotYuSheng/Doppelganger/releases/tag/v0.1.0) tag. The old -`scripts/telegram_extract.py` and `scripts/convert_to_sharegpt.py` still work as thin deprecated -wrappers around `python -m ingest`, but will be removed in a future release. +This project is licensed under the MIT License. See [LICENSE](LICENSE) for details. 
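Reviewer note on the README's `--conversation-gap` and `--message-chain` flags: the sketch below is one possible reading of their documented defaults (3600 s and 30 s). It is an illustration only, not the code in `ingest/`. The `Message` class and both function names are hypothetical, reply-link threading is ignored, and treating a gap exactly equal to the threshold as "same conversation" is an assumption.

```python
from dataclasses import dataclass


@dataclass
class Message:
    """Hypothetical normalized message; the real adapter schema may differ."""
    sender: str
    timestamp: float  # seconds since epoch
    text: str


def sessionize(messages, conversation_gap=3600):
    """Split a time-ordered list wherever the silence exceeds conversation_gap."""
    sessions = []
    for msg in messages:
        if sessions and msg.timestamp - sessions[-1][-1].timestamp <= conversation_gap:
            sessions[-1].append(msg)  # still inside the current conversation
        else:
            sessions.append([msg])  # first message, or silence was too long
    return sessions


def merge_turns(session, message_chain=30):
    """Join consecutive same-sender messages sent within message_chain seconds."""
    turns = []
    last_ts = None  # timestamp of the most recent message seen
    for msg in session:
        same_sender = bool(turns) and turns[-1].sender == msg.sender
        if same_sender and msg.timestamp - last_ts <= message_chain:
            prev = turns[-1]
            turns[-1] = Message(prev.sender, prev.timestamp, prev.text + "\n" + msg.text)
        else:
            turns.append(Message(msg.sender, msg.timestamp, msg.text))
        last_ts = msg.timestamp
    return turns


if __name__ == "__main__":
    demo = [
        Message("me", 0, "hi"),
        Message("me", 10, "there"),
        Message("friend", 20, "yo"),
        Message("me", 5000, "later"),
    ]
    for session in sessionize(demo):
        print([(t.sender, t.text) for t in merge_turns(session)])
```

Measuring the chain window from the previous message rather than from the start of the turn lets a long burst of quick messages collapse into one turn; whether the real pipeline does the same is not stated in this diff.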
diff --git a/configs/train_lora.yaml b/configs/train_lora.yaml index d1499e5..ff4db98 100644 --- a/configs/train_lora.yaml +++ b/configs/train_lora.yaml @@ -33,6 +33,10 @@ plot_loss: true overwrite_output_dir: true ### train +# Loss is computed only on your (assistant/"gpt") turns; human turns are masked. +# This is the SFT default, set explicitly here so --multi-speaker speaker labels +# on the human side are safe — they condition the model but are never generated. +train_on_prompt: false per_device_train_batch_size: 2 gradient_accumulation_steps: 4 learning_rate: 5.0e-5 diff --git a/demo/README.md b/demo/README.md new file mode 100644 index 0000000..530dede --- /dev/null +++ b/demo/README.md @@ -0,0 +1,31 @@ +# demo/ + +Assets and the (dev-only) scripts used to generate the README banner/GIF. None of +this is needed to run Doppelganger — it's tooling for regenerating the visuals. + +| File | What it is | +|------|------------| +| `parrot-mirror.jpg` | Source image for the mascot | +| `mascot.txt` | The parrot converted to braille ASCII (committed art) | +| `sample_export.json` | **Synthetic** Telegram export used by the demo (safe, fake PII) | +| `demo.gif` | The README demo (ingest + sensitive-data scan) | +| `img2ascii.py` | Convert an image to ASCII (brightness ramp) | +| `build_final.py` | Rebuild `ingest/banner.py` and `demo/demo.gif` | + +## Regenerating + +These scripts need extra dev dependencies that the app itself does **not** require: + +```bash +pip install pillow pyfiglet # img2ascii.py / build_final.py +# plus the asciinema 'agg' renderer (https://github.com/asciinema/agg): +# cargo install --git https://github.com/asciinema/agg +# or download a release binary and set: export AGG=/path/to/agg +``` + +Then: + +```bash +python demo/img2ascii.py demo/parrot-mirror.jpg 72 # preview the mascot conversion +python demo/build_final.py # rewrite the banner + demo.gif +``` diff --git a/demo/build_final.py b/demo/build_final.py new file mode 100644 index 
0000000..33014df --- /dev/null +++ b/demo/build_final.py @@ -0,0 +1,101 @@ +#!/usr/bin/env python3 +"""Build the final banner (ingest/banner.py) and demo GIF (demo/demo.gif). + +Layout: parrot (left) + ansi_shadow "Doppel"/"ganger" (right), amber wordmark, +tagline centered beneath. Renders the GIF with agg's dracula theme. +""" +import json +import os +import subprocess + +import pyfiglet + +ROOT = os.path.dirname(os.path.dirname(os.path.abspath(__file__))) +PY = os.path.join(ROOT, "venv", "bin", "python") +AGG = os.environ.get("AGG", "agg") # asciinema agg on PATH; override with $AGG +GAP = " " +TAG = "fine-tune an LLM to write like you" +CMD = "python -m ingest --source telegram --input demo/sample_export.json" +AMBER, RESET = "\x1b[1;38;2;242;176;76m", "\x1b[0m" + +with open(os.path.join(ROOT, "demo/mascot.txt"), encoding="utf-8") as _f: + parrot = _f.read().rstrip("\n").split("\n") +PW = max(len(l) for l in parrot) + + +def _fig(t): + ls = [l.rstrip() for l in pyfiglet.figlet_format(t, font="ansi_shadow", width=200).rstrip("\n").split("\n")] + while ls and not ls[-1].strip(): ls.pop() + while ls and not ls[0].strip(): ls.pop(0) + return ls + + +word = _fig("Doppel") + _fig("ganger") +TOP = (len(parrot) - len(word)) // 2 +TOTAL_W = PW + len(GAP) + max(len(l) for l in word) + + +def rows(on, off): + r = [] + for i, pl in enumerate(parrot): + wl = word[i - TOP] if 0 <= i - TOP < len(word) else "" + wl = f"{on}{wl}{off}" if wl else "" + r.append((pl.ljust(PW) + GAP + wl).rstrip()) + r.append("") + r.append(TAG.center(TOTAL_W).rstrip()) # tagline centred under the whole logo + return r + + +def write_banner_module(): + body = "\n".join(rows("", "")) # sentinels; colourised at runtime + mod = ( + '"""ASCII startup banner: a parrot in a mirror (it mimics your voice; the\n' + 'mirror is the doppelganger) beside the wordmark. The wordmark is amber via\n' + 'truecolor ANSI. 
Regenerate via demo/build_final.py.\n' + 'Set DOPPELGANGER_NO_BANNER=1 to silence it."""\n\n' + 'import os\n\n' + '_AMBER = "\\x1b[1;38;2;242;176;76m" # truecolor amber\n' + '_RESET = "\\x1b[0m"\n\n' + '_BANNER = r"""\n' + body + '\n"""\n\n\n' + 'def print_banner() -> None:\n' + ' if os.environ.get("DOPPELGANGER_NO_BANNER"):\n' + ' return\n' + ' print(_BANNER.replace("", _AMBER).replace("", _RESET) + "\\n")\n' + ) + open(os.path.join(ROOT, "ingest/banner.py"), "w", encoding="utf-8").write(mod) + + +def render_gif(): + env = dict(os.environ, LLM_VALIDATE="false", DOPPELGANGER_NO_BANNER="1") + out = subprocess.run([PY, *CMD.split()[1:]], cwd=ROOT, env=env, + capture_output=True, text=True, check=True) + report = ((out.stdout or "") + (out.stderr or "")).split("\n") + + events, t = [], 0.0 + def emit(d, dt): + nonlocal t + t += dt + events.append([round(t, 3), "o", d]) + emit("\x1b[32m$\x1b[0m ", 0.3) + for ch in CMD: + emit(ch, 0.026) + emit("\r\n", 0.5) + for line in rows(AMBER, RESET) + report: + emit(line + "\r\n", 0.05) + emit("\x1b[32m$\x1b[0m ", 1.6) + + cast = os.path.join(ROOT, "demo/demo.cast") + with open(cast, "w", encoding="utf-8") as f: + f.write(json.dumps({"version": 2, "width": 94, "height": 34}) + "\n") + for ev in events: + f.write(json.dumps(ev, ensure_ascii=False) + "\n") + subprocess.run([AGG, "--font-size", "18", "--theme", "dracula", cast, + os.path.join(ROOT, "demo/demo.gif")], + stdout=subprocess.DEVNULL, stderr=subprocess.PIPE, text=True, check=True) + + +if __name__ == "__main__": + write_banner_module() + render_gif() + print("=== layout preview ===") + print("\n".join(rows("", ""))) diff --git a/demo/demo.gif b/demo/demo.gif new file mode 100644 index 0000000..799a218 Binary files /dev/null and b/demo/demo.gif differ diff --git a/demo/img2ascii.py b/demo/img2ascii.py new file mode 100644 index 0000000..d51a6b6 --- /dev/null +++ b/demo/img2ascii.py @@ -0,0 +1,40 @@ +#!/usr/bin/env python3 +"""Convert an image to ASCII art via a brightness 
ramp. + +Usage: python demo/img2ascii.py PATH [cols] [--invert] +Transparent images are composited onto white first, so a dark subject on a +transparent background (e.g. an OpenMoji black glyph) renders as the dense end +of the ramp. +""" +import sys +from PIL import Image + +RAMP = " .:-=+*#%@" + + +def to_ascii(path: str, cols: int = 42, invert: bool = False) -> str: + img = Image.open(path) + if img.mode in ("RGBA", "LA") or (img.mode == "P" and "transparency" in img.info): + bg = Image.new("RGBA", img.size, (255, 255, 255, 255)) + img = Image.alpha_composite(bg, img.convert("RGBA")) + g = img.convert("L") + rows = max(1, int(cols * (g.height / g.width) * 0.50)) + g = g.resize((cols, rows)) + px = g.load() + ramp = RAMP[::-1] if invert else RAMP + lines = [] + for y in range(rows): + lines.append( + "".join( + ramp[int((255 - px[x, y]) / 255 * (len(ramp) - 1))] for x in range(cols) + ).rstrip() + ) + return "\n".join(lines) + + +if __name__ == "__main__": + args = [a for a in sys.argv[1:] if a != "--invert"] + if not args: + print("Usage: python demo/img2ascii.py PATH [cols] [--invert]", file=sys.stderr) + raise SystemExit(2) + print(to_ascii(args[0], int(args[1]) if len(args) > 1 else 42, "--invert" in sys.argv)) diff --git a/demo/mascot.txt b/demo/mascot.txt new file mode 100644 index 0000000..dc8fbf0 --- /dev/null +++ b/demo/mascot.txt @@ -0,0 +1,17 @@ +⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⡀ +⠀⠀⠀⠀⠀⠀⠀⠀⣀⣤⠶⠖⠛⠛⠋⠉⣿⣿⣿⣿⣷⣶⣤⣄⡀ +⠀⠀⠀⠀⠀⣠⡶⠛⠉⠀⠀⠀⠀⠀⠀⠀⣿⣿⣿⣿⣿⣿⣿⣿⣿⣷⣄⡀ +⠀⠀⠀⣠⡾⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣆ +⠀⠀⣴⢋⣤⣴⣶⣦⣤⡀⠀⠀⠀⠀⠀⠀⣿⣿⣿⣿⣿⣿⡿⠟⠋⠉⠉⠙⠻⢷⡀ +⠀⣸⣿⣿⣿⠿⠛⠻⣿⣿⣦⠀⠀⢀⣀⣤⣉⡙⠻⢿⣿⠏⠀⣠⠶⠚⠓⢶⣄⠀⠁ +⠀⣿⣿⣿⠁⠀⠸⠿⠊⣿⣿⠂⣴⠟⠁⠀⣿⣿⣷⣄⠙⠀⢸⠇⠀⠐⠿⠇⢹⡆ +⢸⣿⣿⣿⡀⠄⠀⠀⢀⣿⠇⣼⠃⠸⠁⠀⣿⣿⣄⣿⣇⠀⠸⣇⠠⠀⠀⠀⣼⠇ +⢸⡿⢿⣿⣿⣷⣶⣾⣿⣿⢰⡏⠀⠀⠀⠀⣿⣿⣿⣿⣿⡀⠀⠈⠛⠶⠶⠛⠁⠀⢠⡇ +⢸⡇⠈⠛⠿⣿⣿⣿⡿⠟⢸⣇⠀⠀⠀⠀⣿⣿⣿⣿⡿⠇⣀⡀⠀⠀⠀⢀⣠⣴⣿⡇ +⢸⡇⠀⠀⠀⠀⠀⠀⠀⠀⢰⡌⢳⣄⠀⠀⣿⣿⣿⠋⣠⠀⣿⣿⣿⣿⣿⣿⣿⣿⣿⡇ +⢸⡇⠀⠀⠀⠀⠀⠀⠀⠀⠈⢷⡀⢹⡆⠀⣿⣿⠃⣼⠏⣰⣿⣿⣿⣿⣿⣿⣿⣿⣿⡇ +⢸⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠳⣄⢻⡄⣿⡏⠐⢋⣴⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⡇ +⢸⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢷⡟⢠⣾⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⡇ +⢸⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣴⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⡇ +⢸⣇⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⡇ +⠈⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠁ diff --git a/demo/parrot-mirror.jpg b/demo/parrot-mirror.jpg new file mode 
100644 index 0000000..17257be Binary files /dev/null and b/demo/parrot-mirror.jpg differ diff --git a/demo/sample_export.json b/demo/sample_export.json new file mode 100644 index 0000000..653b372 --- /dev/null +++ b/demo/sample_export.json @@ -0,0 +1,22 @@ +{ + "personal_information": { "first_name": "Alex", "last_name": "Tan" }, + "chats": { + "list": [ + { + "id": 4815162342, + "messages": [ + { "type": "message", "id": 1, "from": "Jamie", "date_unixtime": "1700000000", "text": "yo did you book the airbnb for the trip?" }, + { "type": "message", "id": 2, "from": "Alex Tan", "date_unixtime": "1700000035", "text": "yeah just sorted it! i'll forward you the invoice" }, + { "type": "message", "id": 3, "from": "Jamie", "date_unixtime": "1700000070", "text": "nice send it to jamie.wong@gmail.com" }, + { "type": "message", "id": 4, "from": "Alex Tan", "date_unixtime": "1700000110", "text": "done. they need your details for check-in too" }, + { "type": "message", "id": 5, "from": "Jamie", "date_unixtime": "1700000150", "text": "ok my number is 8123 4567 and nric S1234567D" }, + { "type": "message", "id": 6, "from": "Alex Tan", "date_unixtime": "1700000190", "text": "got it. i already paid the deposit with card 4111 1111 1111 1111" }, + { "type": "message", "id": 7, "from": "Jamie", "date_unixtime": "1700000215", "text": "lmao did you just type your full card number" }, + { "type": "message", "id": 8, "from": "Alex Tan", "date_unixtime": "1700000245", "text": "haha oops. anyway i'll send the confirmation to alex.tan@example.com" }, + { "type": "message", "id": 9, "from": "Jamie", "date_unixtime": "1700000275", "text": "perfect, see you next week!" }, + { "type": "message", "id": 10, "from": "Alex Tan", "date_unixtime": "1700000300", "text": "see ya, can't wait" } + ] + } + ] + } +} diff --git a/example.env b/example.env new file mode 100644 index 0000000..3caa359 --- /dev/null +++ b/example.env @@ -0,0 +1,22 @@ +# Copy to .env and fill in. 
.env is gitignored — never commit your keys. +# Every value here is OPTIONAL; with none set, ingestion still runs (the LLM +# features just stay off). + +# ── Optional LLM features (quality auditor + LLM redaction) ─────────────────── +# The CORE pipeline (parse -> dataset + regex sensitive-data scan) needs NONE of +# this and runs with no setup. Uncomment below to ALSO enable the LLM auditor / +# redaction. +# +# Run a LOCAL OpenAI-compatible server so your chat data never leaves your machine +# (vLLM, LM Studio, llama.cpp). Serve an open model, then uncomment: +# +# LLM_VALIDATE=true +# LLM_API_BASE_URL=http://localhost:8000/v1 # vLLM (LM Studio uses :1234/v1) +# LLM_MODEL=Qwen/Qwen2.5-7B-Instruct # the model your server serves +# LLM_API_KEY=local # local servers accept any value + +# ── Optional: Hugging Face ──────────────────────────────────────────────────── +# Only needed to download GATED models during training (e.g. Gemma). The default +# Qwen model in configs/train_lora.yaml is open and needs no token. Read by the +# training stack (huggingface_hub), not by this repo's ingestion code. +# HF_TOKEN= diff --git a/ingest/adapters/telegram.py b/ingest/adapters/telegram.py index dabb2b8..323c2f0 100644 --- a/ingest/adapters/telegram.py +++ b/ingest/adapters/telegram.py @@ -64,7 +64,12 @@ def parse( for msg in chat.get("messages", []): if not _is_valid(msg): continue - sender = msg.get("from") + # "from" can be missing/None (e.g. anonymous channel posts); fall + # back to a label so sender_id stays a str (and multi-speaker + # mode doesn't emit a "None: " prefix). 
+ sender = msg.get("from") or "Unknown" + reply_to = msg.get("reply_to_message_id") + msg_id = msg.get("id") messages.append( NormalizedMessage( chat_id=chat_id, @@ -72,6 +77,8 @@ def parse( sender_id=sender, sender_is_self=(sender == self_name), text=_get_text(msg), + message_id=str(msg_id) if msg_id is not None else None, + reply_to_id=str(reply_to) if reply_to is not None else None, ) ) return messages diff --git a/ingest/banner.py b/ingest/banner.py new file mode 100644 index 0000000..8c93b5b --- /dev/null +++ b/ingest/banner.py @@ -0,0 +1,37 @@ +"""ASCII startup banner: a parrot in a mirror (it mimics your voice; the +mirror is the doppelganger) beside the wordmark. The wordmark is amber via +truecolor ANSI. Regenerate via demo/build_final.py. +Set DOPPELGANGER_NO_BANNER=1 to silence it.""" + +import os + +_AMBER = "\x1b[1;38;2;242;176;76m" # truecolor amber +_RESET = "\x1b[0m" + +_BANNER = r""" +⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣀⣀⡀ +⠀⠀⠀⠀⠀⠀⠀⠀⣀⣤⠶⠖⠛⠛⠋⠉⣿⣿⣿⣿⣷⣶⣤⣄⡀ +⠀⠀⠀⠀⠀⣠⡶⠛⠉⠀⠀⠀⠀⠀⠀⠀⣿⣿⣿⣿⣿⣿⣿⣿⣿⣷⣄⡀ ██████╗ ██████╗ ██████╗ ██████╗ ███████╗██╗ +⠀⠀⠀⣠⡾⠋⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣆ ██╔══██╗██╔═══██╗██╔══██╗██╔══██╗██╔════╝██║ +⠀⠀⣴⢋⣤⣴⣶⣦⣤⡀⠀⠀⠀⠀⠀⠀⣿⣿⣿⣿⣿⣿⡿⠟⠋⠉⠉⠙⠻⢷⡀ ██║ ██║██║ ██║██████╔╝██████╔╝█████╗ ██║ +⠀⣸⣿⣿⣿⠿⠛⠻⣿⣿⣦⠀⠀⢀⣀⣤⣉⡙⠻⢿⣿⠏⠀⣠⠶⠚⠓⢶⣄⠀⠁ ██║ ██║██║ ██║██╔═══╝ ██╔═══╝ ██╔══╝ ██║ +⠀⣿⣿⣿⠁⠀⠸⠿⠊⣿⣿⠂⣴⠟⠁⠀⣿⣿⣷⣄⠙⠀⢸⠇⠀⠐⠿⠇⢹⡆ ██████╔╝╚██████╔╝██║ ██║ ███████╗███████╗ +⢸⣿⣿⣿⡀⠄⠀⠀⢀⣿⠇⣼⠃⠸⠁⠀⣿⣿⣄⣿⣇⠀⠸⣇⠠⠀⠀⠀⣼⠇ ╚═════╝ ╚═════╝ ╚═╝ ╚═╝ ╚══════╝╚══════╝ +⢸⡿⢿⣿⣿⣷⣶⣾⣿⣿⢰⡏⠀⠀⠀⠀⣿⣿⣿⣿⣿⡀⠀⠈⠛⠶⠶⠛⠁⠀⢠⡇ ██████╗ █████╗ ███╗ ██╗ ██████╗ ███████╗██████╗ +⢸⡇⠈⠛⠿⣿⣿⣿⡿⠟⢸⣇⠀⠀⠀⠀⣿⣿⣿⣿⡿⠇⣀⡀⠀⠀⠀⢀⣠⣴⣿⡇ ██╔════╝ ██╔══██╗████╗ ██║██╔════╝ ██╔════╝██╔══██╗ +⢸⡇⠀⠀⠀⠀⠀⠀⠀⠀⢰⡌⢳⣄⠀⠀⣿⣿⣿⠋⣠⠀⣿⣿⣿⣿⣿⣿⣿⣿⣿⡇ ██║ ███╗███████║██╔██╗ ██║██║ ███╗█████╗ ██████╔╝ +⢸⡇⠀⠀⠀⠀⠀⠀⠀⠀⠈⢷⡀⢹⡆⠀⣿⣿⠃⣼⠏⣰⣿⣿⣿⣿⣿⣿⣿⣿⣿⡇ ██║ ██║██╔══██║██║╚██╗██║██║ ██║██╔══╝ ██╔══██╗ +⢸⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⠳⣄⢻⡄⣿⡏⠐⢋⣴⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⡇ ╚██████╔╝██║ ██║██║ ╚████║╚██████╔╝███████╗██║ ██║ +⢸⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠈⢷⡟⢠⣾⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⡇ ╚═════╝ ╚═╝ ╚═╝╚═╝ ╚═══╝ ╚═════╝ ╚══════╝╚═╝ ╚═╝ +⢸⡇⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⠀⣴⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⡇ +⢸⣇⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣀⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⣿⡇ 
+⠈⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠉⠁ + + fine-tune an LLM to write like you +""" + + +def print_banner() -> None: + if os.environ.get("DOPPELGANGER_NO_BANNER"): + return + print(_BANNER.replace("", _AMBER).replace("", _RESET) + "\n") diff --git a/ingest/cli.py b/ingest/cli.py index b666f5c..0ead228 100644 --- a/ingest/cli.py +++ b/ingest/cli.py @@ -10,7 +10,9 @@ import os import sys -from ingest import core, sharegpt +import os.path + +from ingest import core, redactor, sharegpt from ingest.adapters import available_sources, get_adapter from ingest.validator import validate_samples @@ -38,6 +40,36 @@ def _load_dotenv(path: str = ".env") -> None: os.environ[key] = value +def _run_llm_redaction(samples, allow_cloud: bool): + """Run the optional LLM redaction pass, guarding against accidental cloud use. + + Returns a (possibly empty) list of LLM findings. Prefers a local endpoint; + if none is configured and cloud use wasn't explicitly allowed, it warns and + skips rather than silently shipping chat data to a third party. + """ + from ingest import llm + + if not llm.is_local() and not allow_cloud: + print( + "[redactor] --llm-redact set but no local endpoint configured. " + "Refusing to send chat data to a hosted API by default. Set " + f"{llm.BASE_URL_ENV} to a local OpenAI-compatible server (Ollama, " + "vLLM, LM Studio, ...), or pass --allow-cloud-redaction to override. " + "Skipping LLM pass." + ) + return [] + + try: + client = llm.get_client() + except (ImportError, EnvironmentError) as e: + print(f"[redactor] LLM redaction unavailable: {e}. 
Skipping LLM pass.") + return [] + + model = llm.model() + print(f"[redactor] LLM redaction scan via {model} ({llm.endpoint_label()})...") + return redactor.llm_scan_samples(samples, client, model) + + def build_parser() -> argparse.ArgumentParser: parser = argparse.ArgumentParser( prog="python -m ingest", @@ -85,6 +117,58 @@ def build_parser() -> argparse.ArgumentParser: help=f"Max seconds between same-sender messages to merge into one turn " f"(default: {core.DEFAULT_MESSAGE_CHAIN}).", ) + parser.add_argument( + "--multi-speaker", + action="store_true", + help="In group chats, keep individual senders and label each user turn " + "with their name (e.g. 'Bob: ...'). Your own turns are never labelled. " + "Default collapses the other side into one speaker.", + ) + parser.add_argument( + "--redact", + choices=["off", "replace", "drop"], + default="off", + help="What to do with detected sensitive data. 'off' (default) only " + "scans and writes a report. 'replace' swaps spans for [CATEGORY] " + "placeholders; 'drop' removes conversations containing detections.", + ) + parser.add_argument( + "--redact-locales", + default="SG", + help="Comma-separated locales for sensitive-data detection (universal " + "patterns always run). Default: SG.", + ) + parser.add_argument( + "--skip-redact-scan", + action="store_true", + help="Skip the sensitive-data scan/report entirely.", + ) + parser.add_argument( + "--llm-redact", + action="store_true", + help="Additionally use an LLM to flag context-dependent sensitive data " + "(names, secrets regex misses). 
Prefers a local endpoint: set " + "LLM_API_BASE_URL, or pass --allow-cloud-redaction to use a hosted API " + "(which sends chat text to a third party).", + ) + parser.add_argument( + "--allow-cloud-redaction", + action="store_true", + help="Permit LLM redaction against a hosted API when no local " + "LLM_API_BASE_URL is configured.", + ) + parser.add_argument( + "--no-audit", + action="store_true", + help="Master off-switch: skip ALL auditing — the regex sensitive-data " + "scan and the LLM quality validation. Just build the dataset.", + ) + parser.add_argument( + "--skip-validation", + action="store_true", + help="Skip only the LLM quality validation (the regex scan still runs " + "unless --skip-redact-scan / --no-audit is also given).", + ) return parser @@ -92,6 +176,9 @@ def main(argv=None) -> int: _load_dotenv() args = build_parser().parse_args(argv) + from ingest import banner + banner.print_banner() + try: adapter = get_adapter(args.source) except ValueError as e: @@ -110,10 +197,46 @@ def main(argv=None) -> int: messages, conversation_gap=args.conversation_gap, message_chain=args.message_chain, + multi_speaker=args.multi_speaker, ) print(f"Extracted {len(samples)} conversation samples.") - samples = validate_samples(samples) + # --no-audit is the master off-switch; the granular flags disable one half. 
+ skip_scan = args.no_audit or args.skip_redact_scan + skip_validation = args.no_audit or args.skip_validation + if args.no_audit: + print("[audit] All auditing disabled (--no-audit) — building dataset as-is.") + + locales = [s.strip() for s in args.redact_locales.split(",") if s.strip()] + llm_findings = [] + if not skip_scan: + report = redactor.scan_samples(samples, locales=locales) + if args.llm_redact: + llm_findings = _run_llm_redaction(samples, args.allow_cloud_redaction) + redactor.merge_llm_findings(report, llm_findings) + report_path = os.path.join(os.path.dirname(output) or ".", "redaction_report.json") + redactor.write_report(report, report_path) + redactor.print_summary(report, report_path, mode=args.redact) + + # --redact is an explicit request, so honour it even when the scan/report was + # skipped — otherwise the dataset would silently keep sensitive data. + if args.redact != "off": + if skip_scan: + print( + f"[redactor] Scan skipped, but --redact {args.redact} was requested — " + "applying regex redaction (note: --llm-redact needs the scan)." + ) + before = len(samples) + samples = redactor.apply( + samples, args.redact, locales=locales, llm_findings=llm_findings + ) + print( + f"[redactor] Applied --redact {args.redact}: " + f"{before} -> {len(samples)} samples." + ) + + if not skip_validation: + samples = validate_samples(samples) if args.format == "sharegpt": written = sharegpt.write_sharegpt(samples, output) diff --git a/ingest/core.py b/ingest/core.py index da50061..81c2e02 100644 --- a/ingest/core.py +++ b/ingest/core.py @@ -56,6 +56,61 @@ def _split_into_conversations( return conversations +def _merge_by_reply( + conversations: List[List[NormalizedMessage]], +) -> List[List[NormalizedMessage]]: + """Stitch back conversations that a silence gap split but a reply connects. + + A time gap is a guess at where one conversation ends. 
An explicit reply link + is ground truth: if a message replies to one in an earlier (same-chat) + conversation, they belong together. We union such conversations and re-sort + each merged group chronologically. + + When no message carries reply metadata (``message_id``/``reply_to_id`` all + ``None``), there is nothing to union and the input is returned unchanged — so + sources without reply data keep the pure time-based behaviour. + """ + n = len(conversations) + if n <= 1: + return conversations + + id_to_conv = { + m.message_id: ci + for ci, conv in enumerate(conversations) + for m in conv + if m.message_id is not None + } + if not id_to_conv: + return conversations + + parent = list(range(n)) + + def find(x: int) -> int: + while parent[x] != x: + parent[x] = parent[parent[x]] + x = parent[x] + return x + + def union(a: int, b: int) -> None: + ra, rb = find(a), find(b) + if ra != rb: + parent[max(ra, rb)] = min(ra, rb) + + for ci, conv in enumerate(conversations): + for m in conv: + target = id_to_conv.get(m.reply_to_id) if m.reply_to_id else None + if target is not None and target != ci: + union(ci, target) + + groups: "Dict[int, List[NormalizedMessage]]" = {} + for ci in range(n): + groups.setdefault(find(ci), []).extend(conversations[ci]) + + # Order merged groups by their earliest message so output stays chronological. + ordered_roots = sorted(groups, key=lambda r: min(m.timestamp for m in groups[r])) + return [sorted(groups[r], key=lambda m: m.timestamp) for r in ordered_roots] + + def _collect_turn( conversation: List[NormalizedMessage], start_idx: int, chain_threshold: int ): @@ -82,35 +137,71 @@ def _collect_turn( return texts, j +def _assemble_turns(raw_turns, multi_speaker: bool) -> Sample: + """Turn ``(sender_id, is_self, text)`` runs into role/text turns. + + Roles: the dataset owner is ``assistant`` (this is what the doppelganger + learns to produce, so it is *never* labelled), everyone else is ``user``. 
+ + Default mode merges adjacent same-role runs, so in a group chat several + people on the "other side" collapse into one ``user`` turn. ``multi_speaker`` + instead keeps each speaker distinct and prefixes ``user`` turns with the + sender (``"Bob: ..."``), only merging consecutive runs from the *same* + sender — preserving who-said-what as conditioning context. + """ + turns: Sample = [] + last_sender = None + + for sender_id, is_self, text in raw_turns: + role = "assistant" if is_self else "user" + + same_role = bool(turns) and turns[-1]["role"] == role + # In multi-speaker mode a user turn only merges with the previous turn + # when it is the same speaker; otherwise distinct speakers stay distinct. + mergeable = same_role and not ( + multi_speaker and role == "user" and last_sender != sender_id + ) + if mergeable: + # Continuation of the same turn — don't repeat the speaker prefix. + turns[-1]["text"] += "\n" + text + else: + value = f"{sender_id}: {text}" if (multi_speaker and role == "user") else text + turns.append({"role": role, "text": value}) + last_sender = sender_id + + return turns + + def build_samples( messages: Iterable[NormalizedMessage], conversation_gap: int = DEFAULT_CONVERSATION_GAP, message_chain: int = DEFAULT_MESSAGE_CHAIN, + multi_speaker: bool = False, ) -> List[Sample]: """Turn normalized messages into multi-turn conversation samples. - Splits each chat into conversations, merges consecutive same-sender messages - into turns, and keeps only conversations containing at least one user turn - and one assistant turn. + Splits each chat into conversations (stitching reply-linked ones back + together), merges consecutive same-sender messages into turns, and keeps + only conversations containing at least one user turn and one assistant turn. + + ``multi_speaker`` preserves and labels individual senders in group chats + (see :func:`_assemble_turns`); the default collapses the other side. 
""" samples: List[Sample] = [] for chat_messages in _group_by_chat(messages): - for conversation in _split_into_conversations(chat_messages, conversation_gap): - turns: Sample = [] + time_convs = _split_into_conversations(chat_messages, conversation_gap) + for conversation in _merge_by_reply(time_convs): + raw_turns = [] i = 0 while i < len(conversation): texts, next_i = _collect_turn(conversation, i, message_chain) if texts: - role = "assistant" if conversation[i].sender_is_self else "user" - turn_text = "\n".join(texts) - # Merge with previous turn if same role (e.g. gap split a block). - if turns and turns[-1]["role"] == role: - turns[-1]["text"] += "\n" + turn_text - else: - turns.append({"role": role, "text": turn_text}) + m = conversation[i] + raw_turns.append((m.sender_id, m.sender_is_self, "\n".join(texts))) i = next_i + turns = _assemble_turns(raw_turns, multi_speaker) roles = {t["role"] for t in turns} if "user" in roles and "assistant" in roles: samples.append(turns) diff --git a/ingest/llm.py b/ingest/llm.py new file mode 100644 index 0000000..b0af03b --- /dev/null +++ b/ingest/llm.py @@ -0,0 +1,101 @@ +"""Shared OpenAI-compatible LLM client. + +One client for every optional LLM feature (quality validation, LLM redaction). +It speaks the OpenAI Chat Completions API, which is the de-facto standard that +local/self-hosted servers also expose — vLLM, LM Studio, llama.cpp's server, +Ollama, LiteLLM, etc. For privacy, run a LOCAL endpoint so your chat text never +leaves your machine; that is the intended setup for this project. + +Environment variables: + LLM_VALIDATE true/false. Default: enabled when LLM_API_KEY or + LLM_API_BASE_URL is set, disabled otherwise. + LLM_API_BASE_URL Base URL of your local OpenAI-compatible server, e.g. + http://localhost:8000/v1 (vLLM) or http://localhost:1234/v1 + (LM Studio). + LLM_MODEL Model id your server serves — required to use the LLM features + (no default). Use the HF repo id, as vLLM / LM Studio do + (e.g. 
"Qwen/Qwen2.5-7B-Instruct"). + LLM_API_KEY API key. Local servers usually accept any value. +""" + +import os + +VALIDATE_ENV = "LLM_VALIDATE" +MODEL_ENV = "LLM_MODEL" +BASE_URL_ENV = "LLM_API_BASE_URL" +API_KEY_ENV = "LLM_API_KEY" + + +def base_url() -> str: + return os.environ.get(BASE_URL_ENV, "").strip() + + +def model() -> str: + """The configured model id, or empty string if unset (no default).""" + return os.environ.get(MODEL_ENV, "").strip() + + +def is_local() -> bool: + """True when a custom (presumably local/self-hosted) endpoint is configured.""" + return bool(base_url()) + + +def _api_key() -> str: + return ( + os.environ.get(API_KEY_ENV, "").strip() + or os.environ.get("OPENAI_API_KEY", "").strip() + ) + + +def should_validate() -> bool: + val = os.environ.get(VALIDATE_ENV, "").strip().lower() + if val == "false": + return False + if val == "true": + return True + # Default: enable when there's something to talk to. + return bool(_api_key() or base_url()) + + +def get_client(): + """Build an OpenAI-compatible client. Raises if unusable (caller handles).""" + try: + from openai import OpenAI + except ImportError: + raise ImportError( + "The 'openai' package is required for LLM features. " + "Install it with: pip install openai" + ) + if not model(): + raise EnvironmentError( + f"{MODEL_ENV} is not set. Set it to the model your local server serves " + f"(e.g. Qwen/Qwen2.5-7B-Instruct), or set {VALIDATE_ENV}=false." + ) + url = base_url() + key = _api_key() + if not key: + if url: + key = "not-needed" # local servers ignore it, but the SDK requires a value + else: + raise EnvironmentError( + f"{API_KEY_ENV} is not set. Set it, point {BASE_URL_ENV} at a " + f"local endpoint, or set {VALIDATE_ENV}=false." 
+ ) + kwargs = {"api_key": key} + if url: + kwargs["base_url"] = url + return OpenAI(**kwargs) + + +def endpoint_label() -> str: + return base_url() or "OpenAI API" + + +def chat(client, model_name: str, prompt: str, max_tokens: int = 256) -> str: + """Single-prompt completion; returns the assistant message text.""" + resp = client.chat.completions.create( + model=model_name, + max_tokens=max_tokens, + messages=[{"role": "user", "content": prompt}], + ) + return (resp.choices[0].message.content or "").strip() diff --git a/ingest/message.py b/ingest/message.py index edde67d..0db1df7 100644 --- a/ingest/message.py +++ b/ingest/message.py @@ -1,6 +1,7 @@ """The normalized, source-agnostic message shape shared by the pipeline.""" from dataclasses import dataclass +from typing import Optional @dataclass @@ -24,6 +25,12 @@ class NormalizedMessage: ("you"). Drives the user/assistant role assignment downstream. text: The plain-text message body (already extracted/cleaned by the adapter). Adapters should only emit messages with non-empty text. + message_id: Source-stable id for this message, used to resolve reply + links. ``None`` if the source has no message ids. + reply_to_id: ``message_id`` of the message this one replies to, or + ``None``. Lets the pipeline thread replies instead of relying on + time order alone. Adapters that lack reply data leave both ``None``, + and grouping falls back to its time-based behaviour. """ chat_id: str @@ -31,3 +38,5 @@ class NormalizedMessage: sender_id: str sender_is_self: bool text: str + message_id: Optional[str] = None + reply_to_id: Optional[str] = None diff --git a/ingest/redaction/__init__.py b/ingest/redaction/__init__.py new file mode 100644 index 0000000..90de76a --- /dev/null +++ b/ingest/redaction/__init__.py @@ -0,0 +1,172 @@ +"""Regex-based sensitive-data detection. 
+ +A *detector* is a named, locale-tagged regex (optionally backed by a checksum +validator) that flags one category of sensitive data — an email, a credit card, +a Singapore NRIC, etc. Detectors register themselves at import time via +:func:`register`, exactly like source adapters do, so adding coverage for a new +country is a single drop-in module under ``ingest/redaction/`` — no changes to +the scanner or the pipeline. + +Detection is **non-destructive**: :func:`scan_text` and :func:`scan_samples` +only *report* matches (as :class:`Finding` objects). Whether to redact is the +user's decision, taken later against the audit report. + +Want to add your country? Copy ``sg.py``, swap in your locale's patterns + +checksum validators, and register them. See ``CONTRIBUTING`` notes in ``sg.py``. +""" + +import re +from dataclasses import dataclass +from typing import Callable, Dict, Iterable, List, Optional, Pattern + +UNIVERSAL = "universal" # locale tag for patterns that are the same worldwide + + +@dataclass(frozen=True) +class Detector: + """One category of sensitive data and how to recognise it. + + Attributes: + name: Unique id, e.g. ``"sg_nric"`` or ``"email"``. + category: Human-facing label shown in reports, e.g. ``"NRIC"``. + locale: ``"universal"`` or an ISO 3166-1 alpha-2 code (``"SG"``). + pattern: Compiled regex. Every full match is a candidate. + severity: ``"low" | "medium" | "high"`` — drives the suggested action. + validator: Optional extra check on the matched string (e.g. Luhn, + NRIC checksum). A candidate is only flagged if it returns True. + This is what turns a noisy regex into a high-precision detector. 
+ """ + + name: str + category: str + locale: str + pattern: Pattern + severity: str = "medium" + validator: Optional[Callable[[str], bool]] = None + + +@dataclass(frozen=True) +class Finding: + """A single detected span of sensitive data within one text.""" + + detector: str + category: str + locale: str + severity: str + start: int + end: int + value: str + preview: str # masked, safe to print/log + + +_REGISTRY: "List[Detector]" = [] + + +def register(detector: Detector) -> Detector: + """Register a detector. Duplicate ``name`` is a programming error.""" + if any(d.name == detector.name for d in _REGISTRY): + raise ValueError(f"Duplicate detector name: {detector.name!r}") + _REGISTRY.append(detector) + return detector + + +def make( + name: str, + category: str, + locale: str, + regex: str, + *, + severity: str = "medium", + flags: int = 0, + validator: Optional[Callable[[str], bool]] = None, +) -> Detector: + """Compile a regex and register it as a detector in one call.""" + return register( + Detector( + name=name, + category=category, + locale=locale, + pattern=re.compile(regex, flags), + severity=severity, + validator=validator, + ) + ) + + +def available_locales() -> "List[str]": + return sorted({d.locale for d in _REGISTRY}) + + +def iter_detectors(locales: Optional[Iterable[str]] = None) -> "List[Detector]": + """Detectors for the given locales. ``None`` means all. + + ``UNIVERSAL`` detectors are always included — email/card/IP look the same + everywhere, so they run regardless of which country was selected. 
+ """ + if locales is None: + return list(_REGISTRY) + wanted = {UNIVERSAL, *locales} + return [d for d in _REGISTRY if d.locale in wanted] + + +def mask(value: str) -> str: + """Mask a value for safe display in a report (keep shape, hide content).""" + if "@" in value: # email: keep first char + domain + local, _, domain = value.partition("@") + head = local[0] if local else "" + return f"{head}***@{domain}" + stripped = value.strip() + if len(stripped) <= 4: + return "*" * len(stripped) + return f"{stripped[:2]}{'*' * (len(stripped) - 3)}{stripped[-1]}" + + +def scan_text(text: str, locales: Optional[Iterable[str]] = None) -> "List[Finding]": + """Return all sensitive-data findings in ``text`` (non-destructive).""" + findings: List[Finding] = [] + for det in iter_detectors(locales): + # A detector may match surrounding context but expose only the sensitive + # span via a named ``id`` group (e.g. require "NRIC" before the number, + # but report just the number). Otherwise the whole match is the value. + report_id = "id" in det.pattern.groupindex + for m in det.pattern.finditer(text): + start, end = m.span("id") if report_id else m.span() + if start == -1: # an optional ``id`` group that didn't match this hit + continue + value = m.group("id") if report_id else m.group() + if det.validator and not det.validator(value): + continue + findings.append( + Finding( + detector=det.name, + category=det.category, + locale=det.locale, + severity=det.severity, + start=start, + end=end, + value=value, + preview=mask(value), + ) + ) + return findings + + +def luhn_valid(number: str) -> bool: + """Luhn checksum — filters most non-card digit runs (phone/IDs/etc).""" + digits = [int(c) for c in number if c.isdigit()] + if len(digits) < 13 or len(digits) > 19: + return False + total = 0 + for i, d in enumerate(reversed(digits)): + if i % 2 == 1: + d *= 2 + if d > 9: + d -= 9 + total += d + return total % 10 == 0 + + +# Importing the package registers the bundled detectors. 
Add a new locale module +# here (and as a file) and its detectors light up everywhere automatically. +from ingest.redaction import universal as _universal # noqa: E402,F401 +from ingest.redaction import sg as _sg # noqa: E402,F401 diff --git a/ingest/redaction/sg.py b/ingest/redaction/sg.py new file mode 100644 index 0000000..6b33b72 --- /dev/null +++ b/ingest/redaction/sg.py @@ -0,0 +1,94 @@ +"""Singapore (SG) sensitive-data detectors. + +This is the reference locale module — copy it to add your own country. + +A good locale detector is *precise*: a bare regex over chat text fires on +everything, so back it with a checksum/validator wherever the identifier has one +(see :func:`nric_valid`). High precision is what keeps the audit report +trustworthy instead of a wall of false positives. + +CONTRIBUTING +------------ +Add a country by creating ``ingest/redaction/<cc>.py`` (``cc`` = ISO 3166-1 +alpha-2, lower-case), registering detectors with :func:`ingest.redaction.make` +and ``locale="<CC>"``, then importing it in ``ingest/redaction/__init__``. +Open items for SG that make good first contributions: + * NRIC **M-series** (introduced 2022) uses a different checksum table — the + regex below intentionally matches only S/T/F/G so it never flags an + M-series number it can't verify. Add the M table + tests. + * UEN (business registration number) detector. +""" + +import re + +from ingest.redaction import make + +# NRIC/FIN checksum tables, indexed by (weighted_sum + offset) % 11. 
+_ST_SUFFIX = "JZIHGFEDCBA" # S (citizen) and T (citizen, 2000+) +_FG_SUFFIX = "XWUTRQPNMLK" # F and G (foreigner / long-term pass) +_WEIGHTS = (2, 7, 6, 5, 4, 3, 2) + + +def nric_valid(value: str) -> bool: + """Validate a Singapore NRIC/FIN by its check digit (S/T/F/G series).""" + value = value.strip().upper() + if len(value) != 9: + return False + prefix, digits, suffix = value[0], value[1:8], value[8] + if prefix not in "STFG" or not digits.isdigit(): + return False + total = sum(int(d) * w for d, w in zip(digits, _WEIGHTS)) + if prefix in "TG": # T and G shift the weighted sum by 4 + total += 4 + table = _ST_SUFFIX if prefix in "ST" else _FG_SUFFIX + return table[total % 11] == suffix + + +# Long form: full S/T/F/G + 7 digits + check letter, verified by checksum. +# Case-insensitive so "s1234567a" typed in lower-case is still caught (the +# validator upper-cases before checking). +make( + "sg_nric", + "NRIC/FIN", + "SG", + r"\b[STFG]\d{7}[A-Z]\b", + severity="high", + flags=re.IGNORECASE, + validator=nric_valid, +) + +# Short form: the last 3 digits + check letter (e.g. "123A"), the way people +# quote "the last 4 of my IC". It has no self-contained checksum and "123A" +# alone matches every block/unit number, so precision comes from REQUIRING an +# NRIC/IC/FIN keyword just before it. Only the ID span (named group) is +# reported, not the keyword. +make( + "sg_nric_short", + "NRIC/FIN (partial)", + "SG", + r"(?:nric|fin|\bic\b)\D{0,8}?(?P(?", "S123456", or "S(123456)" context. The trailing lookahead +# stops it matching the first 6 digits of a longer token (e.g. the NRIC +# "S1234567D", which would otherwise read as "S123456"). 
+make( + "sg_postal", + "POSTAL_CODE", + "SG", + r"(?:[Ss]ingapore\s+|\bS\(?)\d{6}\)?(?![\dA-Za-z])", + severity="low", +) diff --git a/ingest/redaction/universal.py b/ingest/redaction/universal.py new file mode 100644 index 0000000..7db1444 --- /dev/null +++ b/ingest/redaction/universal.py @@ -0,0 +1,76 @@ +"""Locale-independent detectors: same format the world over. + +Email, payment cards, IP/MAC addresses, and vendor API keys don't change by +country, so they live here and always run. Country-specific catches (national +IDs, local phone formats, postal codes) belong in a per-locale module instead. +""" + +import re + +from ingest.redaction import UNIVERSAL, luhn_valid, make + +# --- Contact / network ------------------------------------------------------- + +make( + "email", + "EMAIL", + UNIVERSAL, + r"\b[A-Za-z0-9._%+-]+@[A-Za-z0-9.-]+\.[A-Za-z]{2,}\b", + severity="medium", +) + +make( + "ipv4", + "IP_ADDRESS", + UNIVERSAL, + r"\b(?:(?:25[0-5]|2[0-4]\d|1?\d?\d)\.){3}(?:25[0-5]|2[0-4]\d|1?\d?\d)\b", + severity="low", +) + +make( + "ipv6", + "IP_ADDRESS", + UNIVERSAL, + r"\b(?:[A-Fa-f0-9]{1,4}:){2,7}[A-Fa-f0-9]{1,4}\b", + severity="low", +) + +make( + "mac", + "MAC_ADDRESS", + UNIVERSAL, + r"\b(?:[0-9A-Fa-f]{2}[:-]){5}[0-9A-Fa-f]{2}\b", + severity="low", +) + +# --- Financial --------------------------------------------------------------- + +# Broad 13–19 digit run (optionally space/dash grouped); Luhn rejects the noise. 
+make( + "credit_card", + "CARD_NUMBER", + UNIVERSAL, + r"\b(?:\d[ -]?){13,19}\b", + severity="high", + validator=luhn_valid, +) + +# --- Secrets / credentials --------------------------------------------------- + +make("openai_key", "API_KEY", UNIVERSAL, r"\bsk-[A-Za-z0-9]{20,}\b", severity="high") +make("aws_access_key", "API_KEY", UNIVERSAL, r"\bAKIA[0-9A-Z]{16}\b", severity="high") +make( + "github_token", + "API_KEY", + UNIVERSAL, + r"\bgh[pousr]_[A-Za-z0-9]{36,}\b", + severity="high", +) +make( + "private_key_block", + "PRIVATE_KEY", + UNIVERSAL, + r"-----BEGIN (?:RSA |EC |OPENSSH |DSA |PGP )?PRIVATE KEY-----", + severity="high", + flags=re.IGNORECASE, +) diff --git a/ingest/redactor.py b/ingest/redactor.py new file mode 100644 index 0000000..f8feefa --- /dev/null +++ b/ingest/redactor.py @@ -0,0 +1,275 @@ +"""Non-destructive sensitive-data audit over conversation samples. + +Runs the regex detectors in :mod:`ingest.redaction` across every turn, writes an +audit report, and prints a warning summary. By default **nothing is changed** — +the user reviews the report and decides whether to act. Acting is opt-in via +:func:`apply` (wired to the CLI's ``--redact`` flag): + + - "replace": swap each detected span for a ``[CATEGORY]`` placeholder, keeping + conversational structure intact for training. + - "drop": discard any conversation that contains a detected item. + +Detection is regex-based and locale-aware (Singapore-first); see +``ingest/redaction`` to add coverage for more countries. 
+""" + +import json +import re +from collections import defaultdict +from typing import Iterable, List, Optional + +from ingest import redaction + +DEFAULT_LOCALES = ["SG"] # universal detectors always run in addition to these +_MAX_CONSECUTIVE_LLM_FAILURES = 5 # abort the LLM pass if the endpoint keeps failing + + +def scan_samples(samples, locales: Optional[Iterable[str]] = None) -> dict: + """Scan every turn and return an audit report (no mutation).""" + if locales is None: + locales = DEFAULT_LOCALES + + findings = [] + for ci, turns in enumerate(samples): + for ti, turn in enumerate(turns): + for f in redaction.scan_text(turn.get("text", ""), locales): + findings.append({ + "conversation": ci, + "turn": ti, + "role": turn.get("role"), + "category": f.category, + "detector": f.detector, + "severity": f.severity, + "preview": f.preview, + }) + + summary = {} + convs_per_cat = defaultdict(set) + for f in findings: + s = summary.setdefault( + f["category"], {"hits": 0, "conversations": 0, "severity": f["severity"]} + ) + s["hits"] += 1 + convs_per_cat[f["category"]].add(f["conversation"]) + for cat, s in summary.items(): + s["conversations"] = len(convs_per_cat[cat]) + + return { + "conversations_scanned": len(samples), + "total_findings": len(findings), + "locales": list(locales), + "summary": summary, + "findings": findings, + } + + +def write_report(report: dict, path: str) -> None: + with open(path, "w", encoding="utf-8") as f: + json.dump(report, f, indent=2, ensure_ascii=False) + + +def print_summary(report: dict, report_path: str, mode: str = "off") -> None: + n = report["total_findings"] + if n == 0: + print("[redactor] No sensitive data detected by regex scan.") + return + print( + f"[redactor] WARNING: {n} potential sensitive item(s) detected across " + f"{report['conversations_scanned']} conversations:" + ) + for cat, s in sorted(report["summary"].items(), key=lambda kv: -kv[1]["hits"]): + print( + f" {cat:22s} {s['hits']:4d} hit(s) in 
{s['conversations']:3d} " + f"conversation(s) [{s['severity']}]" + ) + print(f"[redactor] Full report: {report_path}") + if mode == "off": + print( + "[redactor] Nothing was removed. Review it, then re-run with " + "--redact replace (placeholder) or --redact drop (remove conversations)." + ) + + +def _replace_spans(text: str, spans) -> str: + """Replace ``(start, end, category)`` spans with ``[CATEGORY]`` placeholders. + + Overlapping spans are **merged** so the full extent of every detected span is + redacted — no partial leaks (e.g. a shorter span can't shadow part of a longer + one). Each merged region is labelled by its longest contributing span. + Applied right-to-left so earlier offsets stay valid. + """ + merged = [] # [start, end, label_cat, label_len] + for start, end, cat in sorted(set(spans)): + if merged and start < merged[-1][1]: # overlaps the previous region + region = merged[-1] + region[1] = max(region[1], end) + if end - start > region[3]: # longer span -> its label wins + region[2], region[3] = cat, end - start + else: + merged.append([start, end, cat, end - start]) + for start, end, cat, _ in reversed(merged): + text = text[:start] + f"[{cat}]" + text[end:] + return text + + +def apply(samples, mode: str, locales: Optional[Iterable[str]] = None, + llm_findings: Optional[List[dict]] = None) -> List: + """Return samples with detected data handled per ``mode``. + + ``mode`` is "replace" (swap spans for ``[CATEGORY]``) or "drop" (remove any + conversation containing a detection). Regex spans are re-derived per turn; + optional ``llm_findings`` (which carry their own offsets) are applied too. 
+ """ + if locales is None: + locales = DEFAULT_LOCALES + + llm_by_turn = defaultdict(list) + for f in llm_findings or []: + llm_by_turn[(f["conversation"], f["turn"])].append( + (f["start"], f["end"], f["category"]) + ) + + out = [] + for ci, turns in enumerate(samples): + new_turns = [] + drop = False + for ti, turn in enumerate(turns): + text = turn.get("text", "") + spans = [(f.start, f.end, f.category) for f in redaction.scan_text(text, locales)] + spans += llm_by_turn.get((ci, ti), []) + if not spans: + new_turns.append(turn) + continue + if mode == "drop": + drop = True + break + replaced = dict(turn) + replaced["text"] = _replace_spans(text, spans) + new_turns.append(replaced) + if not drop: + out.append(new_turns) + return out + + +# --- Optional LLM detector (Tier 3) ------------------------------------------ +# +# Regex can't catch names or context-dependent secrets. When enabled, the LLM +# reads each conversation and points at sensitive spans *verbatim* (it never +# rewrites the text — that stays the user's decision). Findings flow into the +# same report and the same apply() step as the regex tier. The client/endpoint +# plumbing (incl. LLM_API_BASE_URL for local servers) is shared with the quality +# validator via ingest.llm. + +_LLM_PROMPT = """You are a privacy auditor. Identify spans of SENSITIVE or +PERSONALLY IDENTIFYING information in the conversation below: real people's +names, contact details, addresses, financial or government IDs, credentials, +or health/legal/financial specifics that could identify someone. + +Each turn is numbered "[i] ROLE: text". Do NOT rewrite anything. For each +finding, copy the offending substring EXACTLY as it appears so it can be located. 
+ +Respond with ONLY this JSON: +{{"findings": [{{"turn": , "text": "", "category": "", "severity": "low|medium|high"}}]}} + +Conversation: +{conversation}""" + + +def _format_conversation(turns) -> str: + return "\n".join( + f"[{i}] {t.get('role', '?').upper()}: {t.get('text', '').strip()}" + for i, t in enumerate(turns) + ) + + +def _llm_audit_conversation(client, model, turns) -> List[dict]: + from ingest import llm + + prompt = _LLM_PROMPT.format(conversation=_format_conversation(turns)) + raw = llm.chat(client, model, prompt, max_tokens=512) + match = re.search(r"\{.*\}", raw, re.DOTALL) + if not match: + raise ValueError(f"No JSON object in LLM response: {raw!r}") + return json.loads(match.group()).get("findings", []) + + +def llm_scan_samples(samples, client, model) -> List[dict]: + """LLM pass returning verbatim-located findings (with offsets, in memory). + + Each finding is verified by locating the model's span in the turn text; a + paraphrased span that can't be found is reported as a soft-miss and skipped + rather than trusting an offset we can't confirm. + """ + findings = [] + consecutive_failures = 0 + for ci, turns in enumerate(samples): + try: + raw = _llm_audit_conversation(client, model, turns) + consecutive_failures = 0 + except Exception as e: + consecutive_failures += 1 + print(f"[redactor] LLM scan failed on conversation {ci}: {e}") + if consecutive_failures >= _MAX_CONSECUTIVE_LLM_FAILURES: + print("[redactor] Too many consecutive LLM failures — aborting LLM scan.") + break + continue + for rf in raw: + try: + ti = int(rf["turn"]) + span = str(rf["text"]) + except (KeyError, ValueError, TypeError): + continue + if not (0 <= ti < len(turns)) or not span: + continue + text = turns[ti].get("text", "") + # Record every non-overlapping occurrence — a repeated name/number + # must not leak just because only the first was redacted. 
+ start, located = 0, False + while True: + idx = text.find(span, start) + if idx < 0: + break + located = True + findings.append({ + "conversation": ci, + "turn": ti, + "role": turns[ti].get("role"), + "category": str(rf.get("category", "PII")), + "detector": "llm", + "severity": str(rf.get("severity", "medium")), + "start": idx, + "end": idx + len(span), + "preview": redaction.mask(span), + }) + start = idx + len(span) + if not located: + print(f"[redactor] LLM span not found verbatim (conv {ci}, turn {ti}): {span!r}") + return findings + + +def merge_llm_findings(report: dict, llm_findings: List[dict]) -> dict: + """Fold LLM findings into a regex report (masked previews only; no raw spans).""" + for f in llm_findings: + report["findings"].append({ + "conversation": f["conversation"], + "turn": f["turn"], + "role": f["role"], + "category": f["category"], + "detector": "llm", + "severity": f["severity"], + "preview": f["preview"], + }) + convs_per_cat = defaultdict(set) + for f in report["findings"]: + convs_per_cat[f["category"]].add(f["conversation"]) + summary = {} + for f in report["findings"]: + s = summary.setdefault( + f["category"], {"hits": 0, "conversations": 0, "severity": f["severity"]} + ) + s["hits"] += 1 + for cat, s in summary.items(): + s["conversations"] = len(convs_per_cat[cat]) + report["summary"] = summary + report["total_findings"] = len(report["findings"]) + return report diff --git a/ingest/validator.py b/ingest/validator.py index b702ddd..c7cd0ea 100644 --- a/ingest/validator.py +++ b/ingest/validator.py @@ -1,149 +1,189 @@ """ -Optional LLM-based conversation quality validator. +Optional LLM-based conversation auditor. -Controlled via environment variables: - DIALOGSMITH_LLM_VALIDATE=true/false (default: true if ANTHROPIC_API_KEY is set) - DIALOGSMITH_LLM_MODEL=... (default: claude-haiku-4-5-20251001) - ANTHROPIC_API_KEY=... +Uses the shared OpenAI-compatible client (see :mod:`ingest.llm`), so it runs +against OpenAI or any local server. 
Controlled by the ``LLM_*`` environment +variables documented there (``LLM_VALIDATE``, ``LLM_API_BASE_URL``, ``LLM_MODEL``, +``LLM_API_KEY``). -Each conversation sample is scored on two axes: +Each conversation sample is audited on three axes: - coherence: does this read as a natural, continuous conversation? - quality: is this a meaningful exchange worth training on? + - pairing: does each assistant turn actually respond to what came before? -Samples that fail either check are excluded from the output. -A summary of filtered samples is printed so the user can audit decisions. +Because the heuristic grouper can over-merge, the auditor may also *repair* a +sample by proposing split points rather than only keeping or dropping it: + - action "keep": use as-is + - action "split": cut after the given turn indices into independent samples + - action "drop": discard entirely + +A summary of every decision is printed so the user can audit the auditor. """ import json -import os import re -VALIDATE_ENV = "DIALOGSMITH_LLM_VALIDATE" -MODEL_ENV = "DIALOGSMITH_LLM_MODEL" -DEFAULT_MODEL = "claude-haiku-4-5-20251001" - -COHERENCE_THRESHOLD = 0.5 # 0–1, below this the conversation is considered incoherent -QUALITY_THRESHOLD = 0.5 # 0–1, below this the sample is considered low-quality - - -def _should_validate(): - val = os.environ.get(VALIDATE_ENV, "").strip().lower() - if val == "false": - return False - if val == "true": - return True - # Default: enable if API key is present - return bool(os.environ.get("ANTHROPIC_API_KEY", "").strip()) +from ingest import llm - -def _get_client(): - try: - import anthropic - except ImportError: - raise ImportError( - "The 'anthropic' package is required for LLM validation. " - "Install it with: pip install anthropic" - ) - api_key = os.environ.get("ANTHROPIC_API_KEY", "").strip() - if not api_key: - raise EnvironmentError( - "ANTHROPIC_API_KEY is not set. " - f"Set {VALIDATE_ENV}=false to disable validation." 
- ) - return anthropic.Anthropic(api_key=api_key) +_MAX_CONSECUTIVE_LLM_FAILURES = 5 # abort validation if the endpoint keeps failing +COHERENCE_THRESHOLD = 0.5 # below this the conversation is considered incoherent +QUALITY_THRESHOLD = 0.5 # below this the sample is considered low-quality +PAIRING_THRESHOLD = 0.5 # below this the turns don't respond to each other def _format_conversation(turns): + """Number every turn so the model can reference split points by index.""" lines = [] - for turn in turns: + for i, turn in enumerate(turns): role = turn.get("role", "unknown").upper() text = turn.get("text", "").strip() - lines.append(f"{role}: {text}") + lines.append(f"[{i}] {role}: {text}") return "\n".join(lines) def _score_sample(client, model, turns): - """ - Ask the LLM to score a conversation sample. - Returns (coherence: float, quality: float, reason: str). + """Ask the LLM to audit a conversation sample. + + Returns a dict: coherence, quality, pairing (floats), action + ("keep"|"split"|"drop"), split_after (list[int]), reason (str). """ conversation_text = _format_conversation(turns) - prompt = f"""You are evaluating a conversation sample for use in fine-tuning a language model. + prompt = f"""You are auditing a conversation sample for fine-tuning a language model +to imitate the ASSISTANT speaker. The conversation was segmented by a heuristic +that can wrongly merge unrelated exchanges, so judge it carefully. -Rate the following conversation on two dimensions, each from 0.0 to 1.0: +Each turn is numbered like "[i] ROLE: text". -1. coherence: Does this read as a natural, continuous conversation where each message follows logically from the previous? (0 = completely disjointed, 1 = perfectly coherent) -2. quality: Is this a meaningful, substantive exchange worth training on? Penalise one-word replies, pure greetings, or exchanges with no informational content. (0 = worthless, 1 = highly valuable) +Rate from 0.0 to 1.0: +1. 
coherence: does this read as one natural, continuous conversation? +2. quality: is this a meaningful exchange worth training on? Penalise pure + greetings, one-word replies, and content-free chatter. +3. pairing: does each ASSISTANT turn actually respond to the USER turn(s) before + it? (0 = replies are mismatched/non-sequiturs, 1 = every reply clearly fits) -Respond with ONLY a JSON object in this exact format: -{{"coherence": , "quality": , "reason": ""}} +Then choose an action: +- "keep": the sample is good as one conversation. +- "split": it is really two or more separate conversations. Give "split_after" + as the list of turn indices AFTER which to cut (e.g. [3] cuts between turn 3 + and 4). +- "drop": it is not usable. + +Respond with ONLY this JSON: +{{"coherence": , "quality": , "pairing": , + "action": "keep"|"split"|"drop", "split_after": [...], "reason": ""}} Conversation: {conversation_text}""" - response = client.messages.create( - model=model, - max_tokens=128, - messages=[{"role": "user", "content": prompt}], - ) - - raw = response.content[0].text.strip() - # The model may wrap the JSON in markdown fences or prose; extract the object. 
+ raw = llm.chat(client, model, prompt, max_tokens=200) match = re.search(r"\{.*\}", raw, re.DOTALL) if not match: raise ValueError(f"No JSON object found in LLM response: {raw!r}") result = json.loads(match.group()) - return float(result["coherence"]), float(result["quality"]), result.get("reason", "") + return { + "coherence": float(result["coherence"]), + "quality": float(result["quality"]), + "pairing": float(result.get("pairing", 1.0)), + "action": str(result.get("action", "keep")).lower(), + "split_after": [int(i) for i in result.get("split_after", []) or []], + "reason": result.get("reason", ""), + } + + +def _apply_split(turns, split_after): + """Cut ``turns`` after each given index into independent samples.""" + cuts = sorted({i for i in split_after if 0 <= i < len(turns) - 1}) + if not cuts: + return [turns] + pieces, start = [], 0 + for idx in cuts: + pieces.append(turns[start:idx + 1]) + start = idx + 1 + pieces.append(turns[start:]) + return [p for p in pieces if p] + + +def _has_both_roles(turns): + roles = {t["role"] for t in turns} + return "user" in roles and "assistant" in roles def validate_samples(samples): """ - Validate a list of conversation samples. + Audit a list of conversation samples. Each sample is a list of {"role": ..., "text": ...} dicts (as produced by ingest.core.build_samples). - Returns filtered list of samples that pass validation. - If validation is disabled or unavailable, returns all samples unchanged. + Returns the filtered/repaired list of samples. If validation is disabled or + unavailable, returns all samples unchanged. 
""" - if not _should_validate(): + if not llm.should_validate(): print("[validator] LLM validation disabled — skipping.") return samples try: - client = _get_client() + client = llm.get_client() except (ImportError, EnvironmentError) as e: print(f"[validator] WARNING: {e}") print("[validator] Skipping LLM validation and returning all samples.") return samples - model = os.environ.get(MODEL_ENV, DEFAULT_MODEL).strip() - print(f"[validator] Running LLM validation with model: {model}") + model = llm.model() + print(f"[validator] Auditing with model: {model} via {llm.endpoint_label()}") passed = [] filtered = [] + split_count = 0 + consecutive_failures = 0 for i, turns in enumerate(samples): try: - coherence, quality, reason = _score_sample(client, model, turns) + r = _score_sample(client, model, turns) + consecutive_failures = 0 except Exception as e: + consecutive_failures += 1 print(f"[validator] Sample {i}: scoring failed ({e}), keeping sample.") passed.append(turns) + if consecutive_failures >= _MAX_CONSECUTIVE_LLM_FAILURES: + print("[validator] Too many consecutive LLM failures — keeping remaining samples unvalidated.") + passed.extend(samples[i + 1:]) + break continue - if coherence < COHERENCE_THRESHOLD: - filtered.append((i, "incoherent", coherence, quality, reason)) - elif quality < QUALITY_THRESHOLD: - filtered.append((i, "low-quality", coherence, quality, reason)) + low = ( + r["coherence"] < COHERENCE_THRESHOLD or + r["quality"] < QUALITY_THRESHOLD or + r["pairing"] < PAIRING_THRESHOLD + ) + # Try repair-by-split first: an over-merged sample scores low on + # coherence/pairing, so checking `low` before `split` would always drop + # the very samples split is meant to rescue. 
+ if r["action"] == "split": + pieces = [p for p in _apply_split(turns, r["split_after"]) if _has_both_roles(p)] + if pieces: + passed.extend(pieces) + split_count += 1 + else: + filtered.append((i, "split-empty", r)) + elif r["action"] == "drop" or low: + filtered.append((i, "dropped", r)) else: passed.append(turns) - print(f"[validator] {len(passed)} passed, {len(filtered)} filtered out of {len(samples)} total.") + print( + f"[validator] {len(passed)} samples kept ({split_count} from splits), " + f"{len(filtered)} dropped, from {len(samples)} input samples." + ) if filtered: - print("[validator] Filtered samples:") - for idx, reason_type, coh, qual, reason in filtered: - print(f" sample {idx:4d} | {reason_type:12s} | coherence={coh:.2f} quality={qual:.2f} | {reason}") + print("[validator] Dropped samples:") + for idx, kind, r in filtered: + print( + f" sample {idx:4d} | {kind:11s} | " + f"coh={r['coherence']:.2f} qual={r['quality']:.2f} pair={r['pairing']:.2f} " + f"| {r['reason']}" + ) return passed diff --git a/requirements.txt b/requirements.txt index 44d86bc..6d21e72 100644 --- a/requirements.txt +++ b/requirements.txt @@ -6,6 +6,7 @@ llamafactory==0.9.4 # (see https://pytorch.org/get-started/locally/). Installing llamafactory pulls # a torch build, but it may not match your CUDA version. -# Optional — only needed when LLM-based dataset validation is enabled -# (DIALOGSMITH_LLM_VALIDATE / ANTHROPIC_API_KEY). Safe to remove otherwise. -anthropic>=0.39 +# Optional — only needed for the LLM features (quality validation, LLM +# redaction). Uses the OpenAI-compatible API, so it works with OpenAI or any +# local server (Ollama, vLLM, LM Studio, ...). Safe to remove otherwise. 
+openai>=1.0 diff --git a/scripts/convert_to_sharegpt.py b/scripts/convert_to_sharegpt.py deleted file mode 100644 index 7d99723..0000000 --- a/scripts/convert_to_sharegpt.py +++ /dev/null @@ -1,29 +0,0 @@ -#!/usr/bin/env python3 -"""DEPRECATED shim — kept for backwards compatibility, removed in a future release. - -The new pipeline writes ShareGPT directly: - - python -m ingest --source telegram --format sharegpt - -This shim still converts an existing data/chat_dataset.jsonl into -data/chat_sharegpt.json, delegating to the new ``ingest`` package. -""" - -import os -import sys - -sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from ingest import sharegpt # noqa: E402 - -INPUT_PATH = "./data/chat_dataset.jsonl" -OUTPUT_PATH = "./data/chat_sharegpt.json" - -if __name__ == "__main__": - sys.stderr.write( - "[deprecated] scripts/convert_to_sharegpt.py -> use: " - "python -m ingest --source telegram --format sharegpt\n" - ) - samples = sharegpt.load_jsonl_samples(INPUT_PATH) - written = sharegpt.write_sharegpt(samples, OUTPUT_PATH) - print(f"Converted {written} valid conversation samples to ShareGPT format.") diff --git a/scripts/telegram_extract.py b/scripts/telegram_extract.py deleted file mode 100644 index 548d004..0000000 --- a/scripts/telegram_extract.py +++ /dev/null @@ -1,28 +0,0 @@ -#!/usr/bin/env python3 -"""DEPRECATED shim — kept for backwards compatibility, removed in a future release. - -Use the cross-platform CLI instead: - - python -m ingest --source telegram --format jsonl - -This shim reproduces the old behaviour (Telegram result.json -> -data/chat_dataset.jsonl) by delegating to the new ``ingest`` package. -""" - -import os -import sys - -# Allow running as `python scripts/telegram_extract.py` from the repo root. 
-sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) - -from ingest.cli import main # noqa: E402 - -if __name__ == "__main__": - sys.stderr.write( - "[deprecated] scripts/telegram_extract.py -> use: " - "python -m ingest --source telegram --format jsonl\n" - ) - raise SystemExit( - main(["--source", "telegram", "--format", "jsonl", - "--output", "./data/chat_dataset.jsonl"]) - ) diff --git a/setup.bat b/setup.bat index 48449bf..60f78d8 100644 --- a/setup.bat +++ b/setup.bat @@ -19,8 +19,8 @@ if errorlevel 1 (echo Failed to install dependencies. & exit /b 1) echo [3/4] Preparing .env... if not exist ".env" ( - copy ".env.example" ".env" >nul - echo Created .env from .env.example - edit it to enable optional LLM validation. + copy "example.env" ".env" >nul + echo Created .env from example.env - edit it to enable optional LLM features. ) echo [4/4] Processing Telegram export (data\result.json -^> data\chat_sharegpt.json)... diff --git a/setup.sh b/setup.sh index 41ad5bd..e6afab3 100755 --- a/setup.sh +++ b/setup.sh @@ -20,8 +20,8 @@ echo "[2/4] Installing dependencies (this can take a while)..." echo "[3/4] Preparing .env..." if [ ! -f .env ]; then - cp .env.example .env - echo " Created .env from .env.example — edit it to enable optional LLM validation." + cp example.env .env + echo " Created .env from example.env — edit it to enable optional LLM features." fi echo "[4/4] Processing Telegram export (data/result.json -> data/chat_sharegpt.json)..." 
diff --git a/tests/test_ingest.py b/tests/test_ingest.py index 72b05fb..ca4ab30 100644 --- a/tests/test_ingest.py +++ b/tests/test_ingest.py @@ -21,6 +21,7 @@ from ingest import core, sharegpt from ingest.adapters import available_sources, get_adapter from ingest.adapters.telegram import TelegramAdapter +from ingest.message import NormalizedMessage SELF = "Yu Sheng" @@ -117,6 +118,21 @@ def test_self_name_override(self): alice = [m for m in msgs if m.sender_id == "Alice"][0] self.assertTrue(alice.sender_is_self) + def test_missing_from_becomes_unknown(self): + # "from" can be missing/None (anonymous channel posts); sender_id must + # stay a str rather than None. + with tempfile.TemporaryDirectory() as d: + path = os.path.join(d, "result.json") + with open(path, "w", encoding="utf-8") as f: + json.dump({"personal_information": {"first_name": "Yu", "last_name": "Sheng"}, + "chats": {"list": [{"id": 1, "messages": [ + _msg(None, 100, "anon post"), + _msg(SELF, 110, "reply")]}]}}, f) + msgs = TelegramAdapter().parse(path) + anon = [m for m in msgs if m.timestamp == 100][0] + self.assertEqual(anon.sender_id, "Unknown") + self.assertFalse(anon.sender_is_self) + def test_undetectable_self_name_raises(self): # Without personal_information, auto-detection yields "" — which would # silently drop every conversation. Must fail loudly instead. @@ -148,6 +164,57 @@ def test_gap_splits_conversations(self): self.assertEqual(first_two, EXPECTED_SHAREGPT[:2]) +def _nm(chat, ts, sender, is_self, text, mid=None, reply=None): + return NormalizedMessage( + chat_id=chat, timestamp=ts, sender_id=sender, sender_is_self=is_self, + text=text, message_id=mid, reply_to_id=reply, + ) + + +class ReplyThreadingTest(unittest.TestCase): + def test_reply_stitches_gap_split_conversations(self): + # Two messages an hour+ apart would split into two conversations, but the + # second replies to the first -> they must end up in one sample. 
+ msgs = [ + _nm("c", 1000, "Alice", False, "you free this weekend?", mid="1"), + _nm("c", 1000 + 8000, "Yu", True, "yeah sun works", mid="2", reply="1"), + ] + samples = core.build_samples(msgs) + self.assertEqual(len(samples), 1) + self.assertEqual([t["role"] for t in samples[0]], ["user", "assistant"]) + + def test_no_reply_data_keeps_time_split(self): + # Same timing, no reply link -> still two conversations (one is one-sided + # and dropped), proving threading is a no-op without reply metadata. + msgs = [ + _nm("c", 1000, "Alice", False, "you free this weekend?"), + _nm("c", 1000 + 8000, "Yu", True, "yeah sun works"), + ] + self.assertEqual(core.build_samples(msgs), []) + + +class MultiSpeakerTest(unittest.TestCase): + def _group(self): + return [ + _nm("g", 1, "Bob", False, "q1"), + _nm("g", 2, "Carol", False, "q2"), + _nm("g", 3, "Yu", True, "answer"), + ] + + def test_default_collapses_other_side(self): + out = sharegpt.to_sharegpt(core.build_samples(self._group())) + self.assertEqual(out[0]["conversations"][0], {"from": "human", "value": "q1\nq2"}) + + def test_multi_speaker_labels_users_not_assistant(self): + out = sharegpt.to_sharegpt(core.build_samples(self._group(), multi_speaker=True)) + convs = out[0]["conversations"] + # Distinct speakers stay distinct and are labelled... + self.assertEqual(convs[0], {"from": "human", "value": "Bob: q1"}) + self.assertEqual(convs[1], {"from": "human", "value": "Carol: q2"}) + # ...but the owner's (assistant) turn is never labelled. 
+ self.assertEqual(convs[2], {"from": "gpt", "value": "answer"}) + + class ShareGptTest(unittest.TestCase): def test_role_mapping_and_drop_one_sided(self): samples = [ @@ -167,6 +234,58 @@ def test_jsonl_roundtrip(self): self.assertEqual(sharegpt.load_jsonl_samples(p), samples) +class ValidatorSplitTest(unittest.TestCase): + def test_apply_split_cuts_after_indices(self): + from ingest.validator import _apply_split + turns = [{"role": "user", "text": "a"}, {"role": "assistant", "text": "b"}, + {"role": "user", "text": "c"}, {"role": "assistant", "text": "d"}] + pieces = _apply_split(turns, [1]) + self.assertEqual(len(pieces), 2) + self.assertEqual(pieces[0], turns[:2]) + self.assertEqual(pieces[1], turns[2:]) + + def test_apply_split_ignores_out_of_range(self): + from ingest.validator import _apply_split + turns = [{"role": "user", "text": "a"}, {"role": "assistant", "text": "b"}] + # Index at/after the last turn is meaningless -> no split. + self.assertEqual(_apply_split(turns, [1, 9]), [turns]) + + def test_has_both_roles(self): + from ingest.validator import _has_both_roles + self.assertTrue(_has_both_roles([{"role": "user"}, {"role": "assistant"}])) + self.assertFalse(_has_both_roles([{"role": "user"}, {"role": "user"}])) + + +class _FakeOpenAI: + """Stub OpenAI-compatible client returning canned JSON (no network).""" + def __init__(self, text): + import types + msg = types.SimpleNamespace(content=text) + resp = types.SimpleNamespace(choices=[types.SimpleNamespace(message=msg)]) + self.chat = types.SimpleNamespace( + completions=types.SimpleNamespace(create=lambda **kw: resp)) + + +class ValidatorSplitPriorityTest(unittest.TestCase): + def test_split_runs_even_when_scores_are_low(self): + from ingest import validator, llm + canned = ('{"coherence":0.2,"quality":0.2,"pairing":0.2,' + '"action":"split","split_after":[1],"reason":"two convos"}') + orig_get, orig_should = llm.get_client, llm.should_validate + llm.get_client = lambda: _FakeOpenAI(canned) + 
llm.should_validate = lambda: True + os.environ["LLM_MODEL"] = "x" + try: + sample = [{"role": "user", "text": "a"}, {"role": "assistant", "text": "b"}, + {"role": "user", "text": "c"}, {"role": "assistant", "text": "d"}] + out = validator.validate_samples([sample]) + # Low scores would previously drop it; now split runs first -> 2 pieces. + self.assertEqual(len(out), 2) + finally: + llm.get_client, llm.should_validate = orig_get, orig_should + os.environ.pop("LLM_MODEL", None) + + class RegistryTest(unittest.TestCase): def test_telegram_registered(self): self.assertIn("telegram", available_sources()) @@ -180,7 +299,7 @@ def test_unknown_source_raises(self): class CliTest(unittest.TestCase): def test_end_to_end_sharegpt(self): from ingest.cli import main - os.environ["DIALOGSMITH_LLM_VALIDATE"] = "false" # no API calls + os.environ["LLM_VALIDATE"] = "false" # no API calls with tempfile.TemporaryDirectory() as d: inp = _write_fixture(d) out = os.path.join(d, "chat_sharegpt.json") @@ -193,6 +312,25 @@ def test_unknown_source_exit_code(self): from ingest.cli import main self.assertEqual(main(["--source", "nope"]), 2) + def test_redact_applies_even_when_scan_skipped(self): + # --redact must still redact when the scan is skipped (no silent leak). 
+ from ingest.cli import main + os.environ["LLM_VALIDATE"] = "false" + with tempfile.TemporaryDirectory() as d: + inp = os.path.join(d, "result.json") + with open(inp, "w", encoding="utf-8") as f: + json.dump({"personal_information": {"first_name": "Yu", "last_name": "Sheng"}, + "chats": {"list": [{"id": 1, "messages": [ + _msg("Alice", 100, "mail me at a@b.com"), + _msg(SELF, 110, "ok")]}]}}, f) + out = os.path.join(d, "out.json") + rc = main(["--source", "telegram", "--input", inp, "--output", out, + "--skip-redact-scan", "--redact", "replace"]) + self.assertEqual(rc, 0) + blob = json.dumps(json.load(open(out, encoding="utf-8"))) + self.assertIn("[EMAIL]", blob) + self.assertNotIn("a@b.com", blob) + if __name__ == "__main__": unittest.main() diff --git a/tests/test_redaction.py b/tests/test_redaction.py new file mode 100644 index 0000000..9b8ff15 --- /dev/null +++ b/tests/test_redaction.py @@ -0,0 +1,209 @@ +"""Unit tests for regex-based sensitive-data detection (stdlib, no network).""" + +import os +import sys +import types +import unittest + +sys.path.insert(0, os.path.dirname(os.path.dirname(os.path.abspath(__file__)))) + +from ingest import redaction, redactor +from ingest.redaction.sg import nric_valid + + +def _categories(text, locales=None): + return {f.category for f in redaction.scan_text(text, locales)} + + +class UniversalTest(unittest.TestCase): + def test_email(self): + finds = redaction.scan_text("ping me at john.doe@acme.co please") + self.assertEqual([f.category for f in finds], ["EMAIL"]) + self.assertEqual(finds[0].preview, "j***@acme.co") # masked, not raw + + def test_credit_card_luhn(self): + # Valid Visa test number passes; same length with bad checksum does not. 
+ self.assertIn("CARD_NUMBER", _categories("card 4111 1111 1111 1111")) + self.assertNotIn("CARD_NUMBER", _categories("ref 4111 1111 1111 1112")) + + def test_api_keys(self): + self.assertIn("API_KEY", _categories("token sk-abcdefghij0123456789xyz")) + self.assertIn("API_KEY", _categories("AKIAIOSFODNN7EXAMPLE")) + + def test_ipv4(self): + self.assertIn("IP_ADDRESS", _categories("server at 192.168.1.10")) + self.assertNotIn("IP_ADDRESS", _categories("version 999.999.1.1")) + + +class SingaporeTest(unittest.TestCase): + def test_nric_checksum(self): + # S0000001I is a well-formed example; flipping the suffix must fail. + self.assertTrue(nric_valid("S0000001I")) + self.assertFalse(nric_valid("S0000001A")) + + def test_nric_detected_only_when_valid(self): + self.assertIn("NRIC/FIN", _categories("my ic is S0000001I", ["SG"])) + self.assertNotIn("NRIC/FIN", _categories("code S0000001A", ["SG"])) + + def test_nric_case_insensitive(self): + self.assertIn("NRIC/FIN", _categories("ic s0000001i", ["SG"])) + + def test_nric_short_form_requires_context(self): + # With an NRIC/IC keyword nearby it's flagged... + self.assertIn("NRIC/FIN (partial)", _categories("NRIC 123A", ["SG"])) + self.assertIn("NRIC/FIN (partial)", _categories("my IC is 567B", ["SG"])) + # ...but a bare block/unit number is not. 
+ self.assertNotIn("NRIC/FIN (partial)", _categories("Blk 123A Clementi", ["SG"])) + + def test_nric_short_form_reports_only_the_id(self): + finds = [ + f for f in redaction.scan_text("NRIC 123A", ["SG"]) + if f.category == "NRIC/FIN (partial)" + ] + self.assertEqual(finds[0].value, "123A") # keyword excluded from the span + + def test_phone(self): + self.assertIn("PHONE", _categories("call 9123 4567", ["SG"])) + self.assertIn("PHONE", _categories("call +65 9123 4567", ["SG"])) + + def test_postal_requires_context_and_not_nric(self): + self.assertIn("POSTAL_CODE", _categories("Singapore 560123", ["SG"])) + self.assertIn("POSTAL_CODE", _categories("address S123456", ["SG"])) + # Must NOT fire on the leading 6 digits of an NRIC. + self.assertNotIn("POSTAL_CODE", _categories("ic S1234567D", ["SG"])) + + def test_locale_filtering(self): + # SG detectors don't run when only universal locale is requested. + self.assertNotIn("NRIC/FIN", _categories("ic S0000001I", [])) + + +class OptionalIdGroupTest(unittest.TestCase): + def test_unmatched_optional_id_group_is_skipped(self): + # A detector whose ``id`` group is optional must not crash / mis-offset + # when that group doesn't match a given hit. 
+ det = redaction.make("tmp_opt", "TMP", redaction.UNIVERSAL, r"X(?P<id>\d+)?") + try: + self.assertNotIn("TMP", _categories("X here", [])) # id unmatched -> skipped + finds = [f for f in redaction.scan_text("X42", []) if f.category == "TMP"] + self.assertEqual(finds[0].value, "42") + finally: + redaction._REGISTRY.remove(det) + + +class RegistryTest(unittest.TestCase): + def test_no_duplicate_names(self): + names = [d.name for d in redaction.iter_detectors()] + self.assertEqual(len(names), len(set(names))) + + def test_locales_available(self): + self.assertIn("SG", redaction.available_locales()) + self.assertIn("universal", redaction.available_locales()) + + +class RedactorStageTest(unittest.TestCase): + def _samples(self): + return [ + [{"role": "user", "text": "email me at a@b.com"}, + {"role": "assistant", "text": "sure thing"}], + [{"role": "user", "text": "nothing sensitive here"}, + {"role": "assistant", "text": "ok"}], + ] + + def test_scan_is_nondestructive_and_reports(self): + samples = self._samples() + report = redactor.scan_samples(samples) + self.assertEqual(report["total_findings"], 1) + self.assertIn("EMAIL", report["summary"]) + # Original samples untouched. 
+ self.assertEqual(samples[0][0]["text"], "email me at a@b.com") + + def test_apply_replace_uses_placeholder(self): + out = redactor.apply(self._samples(), "replace") + self.assertEqual(out[0][0]["text"], "email me at [EMAIL]") + self.assertEqual(out[1][0]["text"], "nothing sensitive here") # untouched + + def test_apply_drop_removes_conversation(self): + out = redactor.apply(self._samples(), "drop") + self.assertEqual(len(out), 1) # the one with an email is dropped + self.assertEqual(out[0][0]["text"], "nothing sensitive here") + + +class _FakeClient: + """Stub OpenAI-compatible client returning a canned JSON body (no network).""" + + def __init__(self, text): + message = types.SimpleNamespace(content=text) + resp = types.SimpleNamespace(choices=[types.SimpleNamespace(message=message)]) + completions = types.SimpleNamespace(create=lambda **kw: resp) + self.chat = types.SimpleNamespace(completions=completions) + + +class LlmRedactionTest(unittest.TestCase): + def _samples(self): + return [[{"role": "user", "text": "hi I'm Alice from Acme"}, + {"role": "assistant", "text": "hello"}]] + + def test_verbatim_span_is_located_and_masked(self): + client = _FakeClient( + '{"findings":[{"turn":0,"text":"Alice","category":"NAME","severity":"high"}]}' + ) + finds = redactor.llm_scan_samples(self._samples(), client, "model") + self.assertEqual(len(finds), 1) + self.assertEqual(finds[0]["category"], "NAME") + self.assertEqual(finds[0]["start"], 7) # offset of "Alice" + self.assertEqual(finds[0]["end"], 12) + self.assertNotIn("Alice", finds[0]["preview"]) # masked + + def test_repeated_span_all_located(self): + # Every occurrence of a repeated span must be recorded, not just the first. 
+ samples = [[{"role": "user", "text": "Alice told Alice about Alice"}, + {"role": "assistant", "text": "ok"}]] + client = _FakeClient( + '{"findings":[{"turn":0,"text":"Alice","category":"NAME","severity":"high"}]}' + ) + finds = redactor.llm_scan_samples(samples, client, "model") + self.assertEqual([f["start"] for f in finds], [0, 11, 23]) + + def test_unlocatable_span_is_dropped(self): + # Model paraphrased instead of copying -> can't verify -> skipped. + client = _FakeClient( + '{"findings":[{"turn":0,"text":"Bob","category":"NAME","severity":"high"}]}' + ) + self.assertEqual(redactor.llm_scan_samples(self._samples(), client, "model"), []) + + def test_merge_into_report(self): + report = redactor.scan_samples(self._samples()) # 0 regex findings + llm = [{"conversation": 0, "turn": 0, "role": "user", "category": "NAME", + "severity": "high", "start": 7, "end": 12, "preview": "Al**e"}] + redactor.merge_llm_findings(report, llm) + self.assertEqual(report["total_findings"], 1) + self.assertIn("NAME", report["summary"]) + self.assertNotIn("value", report["findings"][0]) # no raw span persisted + + def test_apply_replace_uses_llm_offsets(self): + llm = [{"conversation": 0, "turn": 0, "category": "NAME", + "start": 7, "end": 12}] + out = redactor.apply(self._samples(), "replace", llm_findings=llm) + self.assertEqual(out[0][0]["text"], "hi I'm [NAME] from Acme") + + def test_replace_spans_merges_overlaps(self): + from ingest.redactor import _replace_spans + # Overlapping spans merge so the WHOLE region is redacted (no partial + # leak); the merged region is labelled by its longest contributing span. + self.assertEqual(_replace_spans("abcdef", [(0, 3, "X"), (2, 5, "Y")]), "[X]f") + self.assertEqual( + _replace_spans("a@b.com x", [(0, 7, "EMAIL"), (2, 7, "DOMAIN")]), + "[EMAIL] x", + ) + # A longer, later-starting span must not be shadowed by a shorter earlier + # one; the whole region is covered and labelled EMAIL. 
+ self.assertEqual( + _replace_spans("0123456789ABCDEFGHIJ", [(0, 8, "CTX"), (5, 20, "EMAIL")]), + "[EMAIL]", + ) + # Non-overlapping spans stay separate. + self.assertEqual(_replace_spans("a b c", [(0, 1, "A"), (4, 5, "C")]), "[A] b [C]") + + +if __name__ == "__main__": + unittest.main()
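
The overlap-merging behaviour that `test_replace_spans_merges_overlaps` asserts can be illustrated with a minimal sketch. This is not the project's `_replace_spans` (its implementation is not shown in this diff); `replace_spans_sketch` is a hypothetical stand-in, written only to satisfy the four assertions above, and it assumes spans are `(start, end, label)` tuples with end-exclusive offsets.

```python
def replace_spans_sketch(text, spans):
    """Replace each (start, end, label) span with "[label]", merging overlaps.

    Hypothetical illustration only; the real ingest.redactor._replace_spans
    may be implemented differently.
    """
    merged = []  # each region: [start, end, label, length_of_longest_span]
    for start, end, label in sorted(spans, key=lambda s: (s[0], s[1])):
        if merged and start < merged[-1][1]:
            # Overlaps the previous region: extend it so nothing leaks.
            region = merged[-1]
            region[1] = max(region[1], end)
            # Label the region by its longest contributing span
            # (the earlier span wins ties).
            if end - start > region[3]:
                region[2], region[3] = label, end - start
        else:
            merged.append([start, end, label, end - start])

    pieces, cursor = [], 0
    for start, end, label, _ in merged:
        pieces.append(text[cursor:start])  # untouched text before the region
        pieces.append(f"[{label}]")        # placeholder for the whole region
        cursor = end
    pieces.append(text[cursor:])           # untouched tail
    return "".join(pieces)


print(replace_spans_sketch("abcdef", [(0, 3, "X"), (2, 5, "Y")]))
print(replace_spans_sketch("a b c", [(0, 1, "A"), (4, 5, "C")]))
```

Sorting by start offset and extending the current region keeps the pass linear after the sort. Tracking the longest contributing span separately from the region bounds is what stops a short, early span from deciding the label of a region that a longer, later span mostly covers.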