diff --git a/README.md b/README.md index 1b0a80f..b934e95 100644 --- a/README.md +++ b/README.md @@ -58,6 +58,22 @@ from clx.models import DocketEntry print(DocketEntry.objects.all().count()) ``` +## Data Ingestion + +You can bulk-add documents to a project using `add_docs()`: + +```python +from clx.models import Project + +project = Project.objects.get(name="My Project") +project.add_docs(["First document text", "Second document text"]) + +# Or with metadata: +project.add_docs([ + {"text": "Document text", "meta": {"source": "courtlistener"}}, +]) +``` + ## Development Here are a few tips for setting up your development environment. diff --git a/cc/CLAUDE.md b/cc/CLAUDE.md deleted file mode 100644 index a51ee85..0000000 --- a/cc/CLAUDE.md +++ /dev/null @@ -1,820 +0,0 @@ -# Claude Experiment: Automated Annotation Workflow - -## Overview - -This experiment explores automating the human annotation workflow for the docket-entry classifier project. The goal is to have Claude perform the iterative steps that humans currently do—creating synthetic annotations, grounding decision boundaries, generating training sets, and preparing data for BERT fine-tuning. - -## Project Context - -- **Target Project**: `docket-entry` - classifying docket entries (motions, orders, etc.) -- **Pipeline**: Synthetic annotations → Human grounding → AI annotation of training sets → BERT fine-tuning - -## Note on Dictation - -The user often uses dictation, so be forgiving of typos and odd formatting in spoken instructions. Defer to what's written in the code for canonical names and spellings (e.g., "Claude" may be transcribed as "Cloud"). - -## Important: Import Pattern - -Always import models using the shim pattern: -```python -from clx.models import Label, LabelHeuristic, LabelDecision, Project -``` -NOT `from clx.app.models import ...`. The shim at `clx/models.py` auto-initializes Django. - ---- - -## Step 1: Creating Heuristics for a Label - -### Purpose - -Heuristics partition the corpus into three buckets for efficient annotation: -- **Excluded**: High-confidence negatives (don't meet minimal conditions) -- **Neutral**: Uncertain cases (meet minimal but not likely conditions) -- **Likely**: High-confidence positives (meet both minimal and likely conditions) - -### Two Types of Heuristics - -#### 1. Query String Heuristics - -Simple keyword conditions using a mini-language: - -| Operator | Meaning | Example | -|----------|---------|---------| -| `,` | AND (all must match) | `motion, court` | -| `\|` | OR (any can match) | `motion\|filing` | -| `~` | NOT (negation) | `~denied` | -| `^` | Starts with | `^Summary` | - -**Precedence**: ORs are nested within ANDs. So `a, b|c` means `(a) AND (b OR c)`. - -**All matching is case-insensitive.** - -Examples: -- `complaint` - contains "complaint" anywhere -- `^motion` - starts with "motion" -- `motion, ~denied` - contains "motion" AND does not contain "denied" -- `motion|application|request` - contains any of these terms -- `^motion, court|judge` - starts with "motion" AND contains "court" or "judge" - -#### 2. 
Custom Function Heuristics - -For complex logic, define a decorated function in `clx/app/custom_heuristics.py`: - -```python -from clx.app.custom_heuristics import custom_heuristic - -def within_first(text, term, n): - """Helper: check if term appears in first n words.""" - first_n = " ".join(text.split()[:n]) - return term in first_n - -@custom_heuristic("docket-entry", "Motion") -def first_3_motion(text, **kwargs): - """Matches if 'motion' appears in the first 3 words.""" - return within_first(text.lower(), "motion", 3) -``` - -The decorator registers the function with: -- `project_id`: Which project this applies to -- `label_name`: Which label this is a heuristic for -- The function receives `text` and must return `True`/`False` - -After adding custom heuristics, sync them: -```python -LabelHeuristic.sync_custom_heuristics() -``` - -### Minimal vs Likely Conditions - -**Minimal Conditions** (`is_minimal=True`): -- Define what MUST be true for a positive example -- Used to exclude obvious negatives -- Should be conservative—avoid false exclusions -- Example: A "Complaint" should contain "complaint" (or common misspellings) - -**Likely Conditions** (`is_likely=True`): -- Define patterns that strongly suggest a positive -- Used to identify easy positive cases -- Can be more aggressive -- Example: Text starting with "Complaint" is very likely a complaint - -### The Three Buckets Logic - -``` -EXCLUDED = does not match ANY minimal heuristic -NEUTRAL = matches at least one minimal BUT no likely heuristics -LIKELY = matches at least one minimal AND at least one likely heuristic -``` - -### Creating and Managing Heuristics - -```python -from clx.models import Label, LabelHeuristic - -# Get the label -label = Label.objects.get(project_id="docket-entry", name="Motion") - -# View existing heuristics -for h in label.heuristics.all(): - print(f"ID: {h.id}") - print(f" Query: {h.querystring or h.custom}") - print(f" is_minimal: {h.is_minimal}, is_likely: {h.is_likely}") - print(f" Matches: {h.num_examples}, Applied: {h.applied_at}") - -# Create a querystring heuristic -heuristic = LabelHeuristic.objects.create( - label=label, - querystring="motion|application|request", - is_minimal=True, -) - -# Apply the heuristic (computes across corpus) -heuristic.apply() - -# Create a likely heuristic -likely_heuristic = LabelHeuristic.objects.create( - label=label, - querystring="^motion", - is_likely=True, -) -likely_heuristic.apply() - -# Check bucket counts -label.refresh_from_db() -print(f"Excluded: {label.num_excluded}") -print(f"Neutral: {label.num_neutral}") -print(f"Likely: {label.num_likely}") -``` - -### Guidelines for Claude - -1. **Start simple** - Don't overthink. Simple keyword matches work well. - -2. **Minimal conditions should be conservative**: - - Ask: "Could there ever be a positive example that doesn't match this?" - - If yes, broaden the condition or add alternatives with `|` - - Include common misspellings, abbreviations, synonyms - -3. **Likely conditions can be aggressive**: - - These just identify easy cases, not all cases - - Prefix matches (`^term`) are often good likely conditions - -4. **Iterate based on counterexamples**: - - If you find a positive example in the "excluded" bucket → expand minimal condition - - If you find obvious positives in "neutral" → add likely conditions - -5. 
**Multiple heuristics combine with OR**: - - Multiple minimal heuristics: excluded if matches NONE of them - - Multiple likely heuristics: likely if matches ANY of them - ---- - -## Step 2: Create Annotation Decisions - -### Purpose - -Decisions are reason-annotated examples that define decision boundaries. They serve two purposes: -1. Document where we're drawing the line on edge cases -2. Provide training examples for the GEPA predictor optimization - -### What Makes a Good Decision - -- **Keep it minimal**: Humans should be able to review all decisions and understand the labeling policy -- **Include obvious examples**: At least one clear positive example ("This is obviously a complaint") -- **Focus on edge cases**: Where the boundary isn't obvious -- **Short reasons**: 1-2 sentences explaining why - -### Examples of Good Decisions - -For a "Complaint" label: -- **Positive**: "Complaint for Damages" → `True`, "This is clearly a complaint filing" -- **Negative**: "Submission of Complaint as Exhibit" → `False`, "This references a complaint but is not the complaint itself" -- **Negative**: "Response to Complaint" → `False`, "This is a response document, not the complaint" - -For a "Motion" label: -- **Positive**: "Motion for Summary Judgment" → `True`, "Standard motion filing" -- **Positive**: "Application for Extension of Time" → `True`, "Applications that request court action are functionally motions" -- **Negative**: "Opposition to Motion" → `False`, "This opposes a motion but is not itself a motion" - -### Creating Decisions - -```python -from clx.models import Label, LabelDecision -from clx import generate_hash - -label = Label.objects.get(project_id="docket-entry", name="Motion") - -# View existing decisions -for d in label.decisions.all(): - print(f"Value: {d.value}") - print(f"Text: {d.text[:100]}...") - print(f"Reason: {d.reason}") - print() - -# Create a decision from text -text = "Motion for Summary Judgment filed by Defendant" -decision = LabelDecision.objects.create( - label=label, - text_hash=generate_hash(text), - text=text, - value=True, - reason="Standard motion filing requesting summary judgment" -) - -# Or create from a search result (has text_hash already) -# See Search section for how to find examples -``` - -### Guidelines for Claude - -1. **Start with 1-2 obvious decisions** per label -2. **Add edge case decisions as you encounter them** during review -3. **Keep reasons brief but clear** - they'll be used for predictor training -4. **Update decisions if needed** - the same text_hash will update the existing decision - ---- - -## Step 3: Sample the Training Set - -### Purpose - -The training set is a diverse sample of examples used for: -- Running predictor inference -- Training fine-tuned BERT models -- Evaluating model performance - -### How Sampling Works - -The trainset samples from multiple sources to ensure diversity: - -1. **Heuristic buckets**: Random samples from excluded, neutral, and likely buckets -2. **Decision neighbors**: Semantic neighbors of each decision (finds similar edge cases) - -Default configuration (configurable per label): -- `trainset_num_excluded`: 1000 examples from excluded bucket -- `trainset_num_neutral`: 1000 examples from neutral bucket -- `trainset_num_likely`: 1000 examples from likely bucket -- `trainset_num_decision_neighbors`: 50 neighbors per decision - -The sampling uses "mesh sort" to select diverse examples (not just random). 
- -### Train vs Eval Split - -- **Train split**: Main sample (ratio=1.0) -- **Eval split**: Smaller sample (ratio=0.2) for evaluation - -### Updating the Trainset - -```python -from clx.models import Label - -label = Label.objects.get(project_id="docket-entry", name="Motion") - -# Configure sampling parameters (optional) -label.trainset_num_excluded = 1000 -label.trainset_num_neutral = 1000 -label.trainset_num_likely = 1000 -label.trainset_num_decision_neighbors = 50 -label.save() - -# Sample the trainset -label.update_trainset() - -# Check what was sampled -print(f"Train examples: {label.trainset_examples.filter(split='train').count()}") -print(f"Eval examples: {label.trainset_examples.filter(split='eval').count()}") -``` - -### When to Resample - -Resample the trainset when: -- You add new decisions (to include their neighbors) -- You change heuristics significantly -- You want different sampling parameters - -**Note**: Resampling will require re-running predictions (Step 4), which costs money. - ---- - -## Step 4: Fit and Run Predictor - -### Purpose - -The predictor is a small LLM (GPT-mini, Gemini Flash, etc.) that classifies examples. It uses GEPA (a DSPY optimization algorithm) to generate an optimized classification prompt based on your decisions. - -### Cost Warning - -Running predictions costs money (~$2-3 per full trainset run). Plan your workflow to minimize re-runs: -- Batch multiple decisions before resampling -- Fix as many issues as possible before re-running predictions -- The iteration loop is: decisions → resample → fit predictor → run predictions → review → repeat - -### Fitting the Predictor - -Fitting uses your decisions to optimize a classification prompt: - -```python -from clx.models import Label - -label = Label.objects.get(project_id="docket-entry", name="Motion") - -# Configure models (optional) -label.inference_model = "openai/gpt-5-mini" # For predictions -label.teacher_model = "openai/gpt-5" # For GEPA optimization -label.save() - -# Fit the predictor (uses decisions as training examples) -label.fit_predictor() -# This will print the cost when done -``` - -### Running Predictions - -After fitting, run predictions across the trainset: - -```python -label.update_trainset_preds(num_threads=128) - -# Check prediction counts -label.refresh_from_db() -print(f"Positive predictions: {label.trainset_num_positive_preds}") -print(f"Negative predictions: {label.trainset_num_negative_preds}") -``` - -### Viewing Predictions with Reasons - -The predictor outputs both a value and a reason for each prediction: - -```python -# View trainset examples with predictions -for ex in label.trainset_examples.filter(pred__isnull=False)[:10]: - print(f"Pred: {ex.pred}") - print(f"Text: {ex.text[:100]}...") - print(f"Reason: {ex.reason}") - print() -``` - ---- - -## Step 5: Train Fine-tuned Models - -### Purpose - -Fine-tuned BERT models are the production output. They're fast and cheap to run at scale. We train them on the predictor's outputs. - -### Two Configs - -- **`main`**: Full training (10 epochs) - the production model -- **`underfit`**: Light training (1 epoch) - useful for finding different failure modes - -Training both configs helps identify disagreements between models. 
- -### Training Process - -Training can be done via CLI or programmatically: - -```bash -# CLI: Train main model (10 epochs) -clx train docket-entry "Motion" main - -# CLI: Train underfit model (1 epoch) -clx train docket-entry "Motion" underfit -``` - -```python -# Programmatic: Train a specific config -label = Label.objects.get(project_id="docket-entry", name="Motion") -label.train_finetune("main") -``` - -The training process: -1. Prepares training data from the trainset -2. Runs training remotely in the cloud -3. Runs predictions on the trainset using the trained model -4. Updates the finetune tags and saves eval results - -### Update All (Recommended) - -The `update_all` method runs the full pipeline, but **only steps that are out of date**: - -```python -label = Label.objects.get(project_id="docket-entry", name="Motion") - -# Run only what's needed based on timestamps -label.update_all() - -# Force run everything regardless of timestamps -label.update_all(force=True) -``` - -This checks timestamps and runs: -1. **Resample trainset** - if decisions are newer than trainset -2. **Fit predictor** - if trainset is newer than predictor -3. **Run predictions** - if predictor is newer than predictions -4. **Train finetunes** - if predictions are newer than finetunes -5. **Run global corpus predictions** - only if `predict=True` and main finetune is newer than global predictions - -```python -# Also run global corpus predictions (step 5) -label.update_all(predict=True) -``` - -### Programmatic Access - -```python -from clx.models import Label - -label = Label.objects.get(project_id="docket-entry", name="Motion") - -# Prepare finetune data (for inspection) -train_data, eval_data, config = label.prepare_finetune("main") - -# Get the trained model pipeline (runs remotely) -pipe = label.get_finetune_run_pipe("main") -predictions = pipe(["some text to classify"], batch_size=16) - -# View finetune results -for ft in label.fintunes.all(): - print(f"Config: {ft.config_name}") - print(f"Results: {ft.eval_results}") -``` - -### Global Corpus Predictions - -After training a finetune, you can run predictions across the **entire corpus** (not just the trainset). This is a separate step because it's more expensive and not always needed during development. - -Global predictions only run for the **main finetune config** (defined on the search model as `main_finetune_config`). - -```python -# Run predictions across entire corpus using the main finetune config -label.predict_finetune() - -# Force restart (clears cache and starts fresh) -label.predict_finetune(force=True) -``` - -The `predict_finetune` method: -1. **Uses the main finetune config** - defined on the search model (e.g., `DocketEntry.main_finetune_config = "main"`) -2. **Is idempotent** - caches progress and picks up where it left off if interrupted -3. **Uses a CSV cache** in the label's data directory (`data_dir/finetune_predictions_cache.csv`) -4. **Updates the global finetune tag** (`ft`) when complete -5. **Sets `predicted_at` timestamp** on the LabelFinetune object -6. 
**Deletes the cache** after successful completion - -**Tags**: -- `trainset:ft:{config}` - Predictions on trainset only (set by `train_finetune`) -- `ft` - Predictions on entire corpus for the main config (set by `predict_finetune`) - -**Timestamps** on LabelFinetune: -- `finetuned_at` - When the model was last trained -- `predicted_at` - When global corpus predictions were last run - ---- - -## Step 6: Review and Iterate - -### Finding Issues - -Use search to find examples where models might be wrong: - -1. **Review disagreements**: Examples where predictor and fine-tunes disagree -2. **Search by heuristic bucket**: Look in neutral bucket for edge cases -3. **Keyword search**: Find specific patterns -4. **Semantic search**: Find examples similar to a known problem - -### Fast Annotations - -For quick fixes without full decision reasons, use fast annotations: - -```python -from clx.models import Label - -label = Label.objects.get(project_id="docket-entry", name="Motion") -model = label.project.get_search_model() - -# Get an example -example = model.objects.get(id=12345) - -# Set annotation (no reason needed) -example.set_annotation(label, True) # Mark as positive -example.set_annotation(label, False) # Mark as negative -example.set_annotation(label, "flag") # Flag for exclusion from trainset -example.set_annotation(label, None) # Clear annotation -``` - -Flagged examples are excluded from the trainset entirely. - -### The Iteration Loop - -1. **Search** for potential issues (disagreements, specific patterns, etc.) -2. **Review** examples and identify errors -3. **Fix** via decisions (for edge cases needing reasons) or fast annotations (for quick fixes) -4. **Batch fixes** - do as many as possible before re-running -5. **Resample trainset** (`label.update_trainset()`) -6. **Refit predictor** (`label.fit_predictor()`) -7. **Re-run predictions** (`label.update_trainset_preds()`) -8. **Retrain models** (CLI train commands) -9. **Repeat** - ---- - -## Search Reference - -The search system is the primary way to find and review examples. - -### Basic Search - -```python -from clx.models import Project - -project = Project.objects.get(id="docket-entry") -model = project.get_search_model() - -# Simple search - returns dict with 'data' key -results = model.objects.search(page=1, page_size=100) -for item in results["data"]: - print(item["id"], item["text"][:80]) -``` - -### Search Parameters - -All parameters go in a `params` dict: - -```python -results = model.objects.search( - active_label_id=label.id, # Required for most filters - params={ - # Heuristic bucket filter - "heuristic_bucket": "excluded" | "neutral" | "likely", - - # Trainset filter - "trainset_split": "train" | "eval" | "both", - - # Predictor prediction filter - "predictor_value": "true" | "false", - - # Manual annotation filter - "annotation_value": "true" | "false" | "flag" | "any" | "none", - - # Find disagreements between models - "review_disagreements": True, - - # Keyword search (uses query string syntax) - "querystring": "motion, ~denied", - }, - page=1, - page_size=100, -) -``` - -### Semantic Search - -Find examples similar to a query or embedding: - -```python -# Search by text similarity -results = model.objects.search( - semantic_sort="motion for summary judgment", - page_size=50, -) - -# Or use an embedding directly -embedding = [0.1, 0.2, ...] 
# 96-dim vector -results = model.objects.search(semantic_sort=embedding) -``` - -### Search Result Format - -Each result includes: - -```python -{ - "id": 12345, - "text_hash": "abc123...", - "text": "Full text of the example", - "tags": [1, 5, 12], # Tag IDs - # If in trainset: - "split": "train" | "eval", - "pred": True | False | None, - "reason": "Predictor's reasoning...", -} -``` - -### Count Only - -```python -result = model.objects.search( - active_label_id=label.id, - params={"heuristic_bucket": "neutral"}, - count=True, -) -print(f"Total: {result['total']}") -``` - -### Query String Syntax (Review) - -| Operator | Meaning | Example | -|----------|---------|---------| -| `,` | AND | `motion, court` | -| `\|` | OR | `motion\|filing` | -| `~` | NOT | `~denied` | -| `^` | Starts with | `^Summary` | - ---- - -## Key Files Reference - -| Component | File | Key Lines | -|-----------|------|-----------| -| Models | `clx/app/models.py` | Full file | -| Search | `clx/app/search_utils.py` | `SearchQuerySet.search` | -| Heuristics | `clx/app/models.py` | `LabelHeuristic` class | -| Custom Heuristics | `clx/app/custom_heuristics.py` | Decorator pattern | -| Train CLI | `clx/cli/train.py` | Export/train/import | -| Views | `clx/app/views.py` | All endpoints | -| **Helpers** | `experiment/helpers.py` | Claude Code utilities | - ---- - -## Helper Scripts for Claude Code - -The `experiment/helpers.py` module provides convenient functions for the annotation workflow: - -### Quick Status Check - -```python -from experiment.helpers import print_label_status - -print_label_status("Motion") -``` - -### Searching and Viewing Examples - -```python -from experiment.helpers import ( - search_examples, - print_examples, - disagreements, - neutral_examples, - similar_to, -) - -# Find disagreements between models -examples = disagreements("Motion") -print_examples(examples) - -# Look at edge cases (neutral bucket) -examples = neutral_examples("Motion", page_size=10) -print_examples(examples) - -# Find similar examples -examples = similar_to("Motion", "application for extension of time") -print_examples(examples, show_full_text=True) - -# Complex search -examples = search_examples( - "Motion", - heuristic_bucket="neutral", - querystring="application", - page_size=20, -) -print_examples(examples) -``` - -### Creating Decisions - -```python -from experiment.helpers import ( - create_decision, - create_decision_from_id, - view_decisions, -) - -# View existing decisions -view_decisions("Motion") - -# Create from text -create_decision( - "Motion", - text="Application for Extension of Time", - value=True, - reason="Applications requesting court action are functionally motions" -) - -# Create from search result ID -create_decision_from_id( - "Motion", - example_id=12345, - value=False, - reason="This is a response to a motion, not a motion itself" -) -``` - -### Fast Annotations - -```python -from experiment.helpers import annotate - -annotate("Motion", example_id=12345, value=True) # Positive -annotate("Motion", example_id=12346, value=False) # Negative -annotate("Motion", example_id=12347, value="flag") # Exclude from trainset -``` - -### Creating Heuristics - -```python -from experiment.helpers import create_heuristic - -create_heuristic( - "Motion", - querystring="motion|application|request", - is_minimal=True, - apply=True, # Immediately computes across corpus -) -``` - ---- - -## Scales OKN Integration (docket-entry only) - -For the docket-entry project, we have predictions from Scales OKN—a similar 
classification project with pre-trained models for many of the same labels. - -### Available Scales Labels - -The following labels have Scales OKN predictions imported: - -| Scales Label | Our Label | -|--------------|-----------| -| summons | Summons | -| waiver | Waiver | -| brief | Brief / Memorandum | -| arrest | Arrest | -| warrant | Warrant | -| verdict | Verdict | -| answer | Answer | -| complaint | Complaint | -| indictment | Indictment | -| information | Information | -| petition | Petition | -| notice | Notice | -| response | Reply / Response | -| minute entry | Minute Entry | -| plea agreement | Plea Agreement | -| judgment | Judgment | -| stipulation | Stipulation | -| motion | Motion | -| order | Order | - -### How Scales Tags Work - -- Each label has a `LabelTag` with `name="scales"` -- Positive Scales predictions (score > 0.5) are tagged -- Absence of tag means Scales predicted negative (or no prediction) - -### Using Scales for Review - -Scales predictions are another source of feedback when reviewing. You can compare: -- Examples where our models predict TRUE but Scales predicts FALSE -- Examples where our models predict FALSE but Scales predicts TRUE - -**Important caveats:** -1. **Scales is not ground truth** - it has errors and may make different annotation decisions -2. **Scope to trainset** - we only compute our predictions on the trainset, so compare within trainset -3. **Check against decisions** - if our models disagree with Scales but are consistent with our documented decisions, that's fine - -### Searching with Scales - -```python -from experiment.helpers import search_examples - -# Find examples where we predict TRUE but Scales predicts FALSE -# (These might be cases Scales missed, or cases we're wrong about) -label = get_label("Motion") -scales_tag = label.labeltag_set.filter(name="scales").first() - -# Search for trainset examples our predictor says TRUE -examples = search_examples( - "Motion", - trainset_split="train", - predictor_value="true", -) - -# Filter to those without scales tag (Scales said FALSE) -# This requires checking tags manually or using raw search -``` - -### When to Use Scales Feedback - -- **After initial model training** - to find potential blind spots -- **When reviewing disagreements** - as an additional signal -- **NOT as automatic corrections** - always review why there's a disagreement - ---- - -## Notes - -- The docket-entry project uses `DocketEntry` as the search model -- Heuristics create `LabelTag` entries attached to documents via PostgreSQL array fields -- The `apply()` step processes documents in batches of 1M for efficiency -- Predictions cost money - batch your changes before re-running -- The main fine-tune model is the production output; underfit helps find different errors diff --git a/cc/helpers.py b/cc/helpers.py deleted file mode 100644 index 1bb3e9e..0000000 --- a/cc/helpers.py +++ /dev/null @@ -1,601 +0,0 @@ -""" -Helper functions for Claude Code to interact with the CLX annotation workflow. 
- -These utilities make it easier to: -- Search and view examples with predictions -- Create decisions and annotations -- Check label status -""" - -from clx import generate_hash -from clx.models import Label, LabelDecision, LabelHeuristic, Project - - -def get_label(label_name: str, project_id: str = "docket-entry") -> Label: - """Get a label by name.""" - return Label.objects.get(project_id=project_id, name=label_name) - - -def get_project(project_id: str = "docket-entry") -> Project: - """Get a project by ID.""" - return Project.objects.get(id=project_id) - - -def label_status(label_name: str, project_id: str = "docket-entry") -> dict: - """Get comprehensive status of a label including warnings and decisions.""" - label = get_label(label_name, project_id) - - # Get all decisions with full info - decisions = [ - { - "id": d.id, - "value": d.value, - "reason": d.reason, - "text": d.text, - "created_at": d.created_at, - "updated_at": d.updated_at, - } - for d in label.decisions.all().order_by("-updated_at") - ] - - # Get finetunes with timestamps - finetunes = [ - { - "config": ft.config_name, - "results": ft.eval_results, - "created_at": ft.created_at, - "updated_at": ft.updated_at, - "finetuned_at": ft.finetuned_at, - "predicted_at": ft.predicted_at, - "is_main": ft.config_name - == label.project.get_search_model().main_finetune_config, - } - for ft in label.fintunes.all() - ] - - # Generate warnings based on timestamp comparisons - warnings = [] - - # Get latest decision timestamp - latest_decision_at = None - if decisions: - latest_decision_at = max(d["updated_at"] for d in decisions) - - # Warning: decisions newer than trainset - if latest_decision_at and label.trainset_updated_at: - if latest_decision_at > label.trainset_updated_at: - warnings.append( - "Decisions updated since last trainset sampling - consider resampling" - ) - elif latest_decision_at and not label.trainset_updated_at: - warnings.append("Trainset has never been sampled") - - # Warning: trainset newer than predictor - if label.trainset_updated_at and label.predictor_updated_at: - if label.trainset_updated_at > label.predictor_updated_at: - warnings.append( - "Trainset updated since last predictor fit - consider refitting" - ) - elif label.trainset_updated_at and not label.predictor_updated_at: - warnings.append("Predictor has never been fit") - - # Warning: predictor newer than predictions - if label.predictor_updated_at and label.trainset_predictions_updated_at: - if label.predictor_updated_at > label.trainset_predictions_updated_at: - warnings.append( - "Predictor updated since last prediction run - consider rerunning predictions" - ) - elif ( - label.predictor_updated_at - and not label.trainset_predictions_updated_at - ): - warnings.append("Predictions have never been run") - - # Warning: predictions newer than finetunes - if label.trainset_predictions_updated_at and finetunes: - for ft in finetunes: - finetuned_at = ft.get("finetuned_at") - if ( - not finetuned_at - or label.trainset_predictions_updated_at > finetuned_at - ): - warnings.append( - f"Predictions updated since '{ft['config']}' finetune - consider retraining" - ) - - # Warning: finetunes newer than global predictions - for ft in finetunes: - finetuned_at = ft.get("finetuned_at") - predicted_at = ft.get("predicted_at") - if finetuned_at and (not predicted_at or finetuned_at > predicted_at): - warnings.append( - f"'{ft['config']}' finetune updated since global predictions - consider running predict_finetune" - ) - - return { - "name": label.name, - 
"id": label.id, - "warnings": warnings, - "heuristic_buckets": { - "excluded": label.num_excluded, - "neutral": label.num_neutral, - "likely": label.num_likely, - }, - "heuristics": [ - { - "id": h.id, - "query": h.querystring or f"[custom: {h.custom}]", - "is_minimal": h.is_minimal, - "is_likely": h.is_likely, - "num_examples": h.num_examples, - "applied_at": h.applied_at, - } - for h in label.heuristics.all() - ], - "decisions": decisions, - "trainset": { - "train": label.trainset_examples.filter(split="train").count(), - "eval": label.trainset_examples.filter(split="eval").count(), - "updated_at": label.trainset_updated_at, - }, - "predictor": { - "positive_preds": label.trainset_num_positive_preds, - "negative_preds": label.trainset_num_negative_preds, - "updated_at": label.trainset_predictions_updated_at, - "fitted_at": label.predictor_updated_at, - "inference_model": label.inference_model, - "teacher_model": label.teacher_model, - }, - "finetunes": finetunes, - } - - -def print_label_status(label_name: str, project_id: str = "docket-entry"): - """Print a formatted label status report.""" - status = label_status(label_name, project_id) - - print(f"=== Label: {status['name']} (ID: {status['id']}) ===\n") - - # Show warnings prominently at the top - if status["warnings"]: - print("WARNINGS:") - for warning in status["warnings"]: - print(f" ! {warning}") - print() - - print("Heuristic Buckets:") - for bucket, count in status["heuristic_buckets"].items(): - print(f" {bucket}: {count:,}") - - print(f"\nHeuristics ({len(status['heuristics'])}):") - for h in status["heuristics"]: - flags = [] - if h["is_minimal"]: - flags.append("minimal") - if h["is_likely"]: - flags.append("likely") - flag_str = f" [{', '.join(flags)}]" if flags else "" - print(f" {h['query']}{flag_str} → {h['num_examples']:,} matches") - - print(f"\nDecisions ({len(status['decisions'])}):") - for d in status["decisions"]: - value_str = "TRUE" if d["value"] else "FALSE" - text_preview = ( - d["text"][:80] + "..." - if d["text"] and len(d["text"]) > 80 - else d["text"] - ) - print(f" [{value_str}] {text_preview}") - print(f" Reason: {d['reason']}") - - print("\nTrainset:") - print(f" Train: {status['trainset']['train']:,}") - print(f" Eval: {status['trainset']['eval']:,}") - print(f" Updated: {status['trainset']['updated_at']}") - - print("\nPredictor:") - print(f" Positive preds: {status['predictor']['positive_preds']:,}") - print(f" Negative preds: {status['predictor']['negative_preds']:,}") - print(f" Fitted: {status['predictor']['fitted_at']}") - print(f" Predictions updated: {status['predictor']['updated_at']}") - - if status["finetunes"]: - print("\nFinetunes:") - for ft in status["finetunes"]: - print(f" {ft['config']}:") - print(f" Results: {ft['results']}") - print(f" Finetuned at: {ft['finetuned_at']}") - print(f" Global predictions at: {ft['predicted_at']}") - - -def search_examples( - label_name: str, - project_id: str = "docket-entry", - heuristic_bucket: str | None = None, - trainset_split: str | None = None, - predictor_value: str | None = None, - annotation_value: str | None = None, - review_disagreements: bool = False, - querystring: str | None = None, - semantic_sort: str | None = None, - page: int = 1, - page_size: int = 20, -) -> list[dict]: - """ - Search for examples with full context. 
- - Returns examples with: - - text and metadata - - predictor prediction and reason (if in trainset) - - finetune predictions - - annotation status - """ - label = get_label(label_name, project_id) - model = label.project.get_search_model() - - params = {} - if heuristic_bucket: - params["heuristic_bucket"] = heuristic_bucket - if trainset_split: - params["trainset_split"] = trainset_split - if predictor_value: - params["predictor_value"] = predictor_value - if annotation_value: - params["annotation_value"] = annotation_value - if review_disagreements: - params["review_disagreements"] = True - if querystring: - params["querystring"] = querystring - - search_kwargs = { - "active_label_id": label.id, - "params": params, - "page": page, - "page_size": page_size, - } - if semantic_sort: - search_kwargs["semantic_sort"] = semantic_sort - - results = model.objects.search(**search_kwargs) - - # Enrich with tag information - enriched = [] - for item in results.get("data", []): - tags = item.get("tags", []) - - # Check annotation status - anno_status = None - if label.anno_true_tag.id in tags: - anno_status = "true" - elif label.anno_false_tag.id in tags: - anno_status = "false" - elif label.anno_flag_tag.id in tags: - anno_status = "flag" - - # Check finetune predictions - finetune_preds = {} - for ft in label.fintunes.all(): - ft_tag = label.get_trainset_finetune_tag(ft.config_name) - finetune_preds[ft.config_name] = ft_tag.id in tags - - # Check predictor prediction - predictor_pred = label.trainset_pred_tag.id in tags - - enriched.append( - { - "id": item["id"], - "text_hash": item["text_hash"], - "text": item["text"], - "annotation": anno_status, - "predictor_pred": predictor_pred - if item.get("split") - else None, - "predictor_reason": item.get("reason"), - "finetune_preds": finetune_preds if finetune_preds else None, - "trainset_split": item.get("split"), - } - ) - - return enriched - - -def print_examples( - examples: list[dict], - show_full_text: bool = False, - max_text_len: int = 120, -): - """Print examples in a readable format.""" - for i, ex in enumerate(examples, 1): - print(f"\n{'=' * 60}") - print(f"[{i}] ID: {ex['id']}") - - text = ex["text"] - if not show_full_text and len(text) > max_text_len: - text = text[:max_text_len] + "..." - print(f"Text: {text}") - - # Predictions - preds = [] - if ex.get("predictor_pred") is not None: - preds.append(f"predictor={ex['predictor_pred']}") - if ex.get("finetune_preds"): - for config, pred in ex["finetune_preds"].items(): - preds.append(f"{config}={pred}") - if preds: - print(f"Predictions: {', '.join(preds)}") - - if ex.get("predictor_reason"): - print(f"Reason: {ex['predictor_reason']}") - - if ex.get("annotation"): - print(f"Annotation: {ex['annotation']}") - - if ex.get("trainset_split"): - print(f"Split: {ex['trainset_split']}") - - -def view_decisions(label_name: str, project_id: str = "docket-entry"): - """View all decisions for a label.""" - label = get_label(label_name, project_id) - - print(f"=== Decisions for {label_name} ===\n") - - for d in label.decisions.all().order_by("-updated_at"): - value_str = "TRUE" if d.value else "FALSE" - print(f"[{value_str}] {d.text[:100]}...") - print(f" Reason: {d.reason}") - print() - - -def create_decision( - label_name: str, - example_id: int, - value: bool, - reason: str, - project_id: str = "docket-entry", -) -> LabelDecision: - """ - Create a decision from an example ID. - - IMPORTANT: Always use example IDs from search results, not raw text. 
- This ensures the text_hash matches documents in the search table. - """ - label = get_label(label_name, project_id) - model = label.project.get_search_model() - example = model.objects.get(id=example_id) - - text = example.text - text_hash = generate_hash(text) - - decision, created = LabelDecision.objects.update_or_create( - label=label, - text_hash=text_hash, - defaults={ - "text": text, - "value": value, - "reason": reason, - }, - ) - - action = "Created" if created else "Updated" - print(f"{action} decision: {value} - {reason}") - return decision - - -def annotate( - label_name: str, - example_id: int, - value: bool | str | None, - project_id: str = "docket-entry", -): - """ - Set a fast annotation on an example. - - value can be: - - True: positive - - False: negative - - "flag": exclude from trainset - - None: clear annotation - """ - label = get_label(label_name, project_id) - model = label.project.get_search_model() - example = model.objects.get(id=example_id) - example.set_annotation(label, value) - print(f"Set annotation {value} on example {example_id}") - - -def create_heuristic( - label_name: str, - querystring: str, - is_minimal: bool = False, - is_likely: bool = False, - apply: bool = False, - project_id: str = "docket-entry", -) -> LabelHeuristic: - """Create and optionally apply a heuristic. Defaults to NOT applying.""" - label = get_label(label_name, project_id) - - heuristic = LabelHeuristic.objects.create( - label=label, - querystring=querystring, - is_minimal=is_minimal, - is_likely=is_likely, - ) - - if apply: - print(f"Applying heuristic: {querystring}") - heuristic.apply() - label.refresh_from_db() - print(f"Matches: {heuristic.num_examples:,}") - print( - f"Buckets - Excluded: {label.num_excluded:,}, Neutral: {label.num_neutral:,}, Likely: {label.num_likely:,}" - ) - - return heuristic - - -# Quick aliases for common operations -def disagreements(label_name: str, page_size: int = 20, **kwargs): - """Find examples where models disagree.""" - return search_examples( - label_name, - review_disagreements=True, - page_size=page_size, - **kwargs, - ) - - -def neutral_examples(label_name: str, page_size: int = 20, **kwargs): - """Get examples from the neutral bucket (edge cases).""" - return search_examples( - label_name, - heuristic_bucket="neutral", - page_size=page_size, - **kwargs, - ) - - -def likely_examples(label_name: str, page_size: int = 20, **kwargs): - """Get examples from the likely bucket (probable positives).""" - return search_examples( - label_name, - heuristic_bucket="likely", - page_size=page_size, - **kwargs, - ) - - -def excluded_examples(label_name: str, page_size: int = 20, **kwargs): - """Get examples from the excluded bucket (probable negatives).""" - return search_examples( - label_name, - heuristic_bucket="excluded", - page_size=page_size, - **kwargs, - ) - - -def similar_to(label_name: str, text: str, page_size: int = 20, **kwargs): - """Find examples semantically similar to the given text.""" - return search_examples( - label_name, - semantic_sort=text, - page_size=page_size, - **kwargs, - ) - - -def set_instructions( - label_name: str, instructions: str, project_id: str = "docket-entry" -): - """Set the instructions for a label.""" - label = get_label(label_name, project_id) - label.instructions = instructions - label.save() - print(f"Instructions saved for {label_name}") - - -def view_heuristics(label_name: str, project_id: str = "docket-entry"): - """View all heuristics for a label.""" - label = get_label(label_name, project_id) - 
print(f"=== Heuristics for {label_name} ===\n") - for h in label.heuristics.all(): - flags = [] - if h.is_minimal: - flags.append("minimal") - if h.is_likely: - flags.append("likely") - flag_str = f"[{', '.join(flags)}]" if flags else "[none]" - print( - f" {flag_str} {h.querystring or f'[custom: {h.custom}]'} -> {h.num_examples:,} matches" - ) - - -def count_matches(querystring: str, project_id: str = "docket-entry") -> int: - """Count how many examples match a querystring without creating a heuristic.""" - project = get_project(project_id) - model = project.get_search_model() - result = model.objects.search( - params={"querystring": querystring}, - count=True, - ) - return result.get("total", 0) - - -def search_by_query( - querystring: str, - project_id: str = "docket-entry", - page: int = 1, - page_size: int = 20, -) -> list[dict]: - """Search examples by querystring (without needing a label).""" - project = get_project(project_id) - model = project.get_search_model() - results = model.objects.search( - params={"querystring": querystring}, - page=page, - page_size=page_size, - ) - return [ - { - "id": item["id"], - "text": item["text"], - "text_hash": item["text_hash"], - } - for item in results.get("data", []) - ] - - -def print_search_results(results: list[dict], max_text_len: int = 120): - """Print search results in a compact format.""" - for i, r in enumerate(results, 1): - text = r["text"] - if len(text) > max_text_len: - text = text[:max_text_len] + "..." - print(f"[{i}] (id={r['id']}) {text}") - - -def setup_label( - label_name: str, - minimal_query: str, - likely_query: str, - instructions: str, - project_id: str = "docket-entry", - apply: bool = False, -): - """Set up a label with minimal heuristic, likely heuristic, and instructions. Defaults to NOT applying.""" - label = get_label(label_name, project_id) - - # Set instructions - label.instructions = instructions - label.save() - print(f"Instructions saved for {label_name}") - - # Create minimal heuristic - minimal_h = LabelHeuristic.objects.create( - label=label, - querystring=minimal_query, - is_minimal=True, - ) - print(f"Created minimal heuristic: {minimal_query}") - - # Create likely heuristic - likely_h = LabelHeuristic.objects.create( - label=label, - querystring=likely_query, - is_likely=True, - ) - print(f"Created likely heuristic: {likely_query}") - - if apply: - print("Applying heuristics...") - minimal_h.apply() - likely_h.apply() - label.refresh_from_db() - print(f"Minimal matches: {minimal_h.num_examples:,}") - print(f"Likely matches: {likely_h.num_examples:,}") - print( - f"Buckets - Excluded: {label.num_excluded:,}, Neutral: {label.num_neutral:,}, Likely: {label.num_likely:,}" - ) - - return label diff --git a/clx/__init__.py b/clx/__init__.py index 36613ae..0eeb93a 100644 --- a/clx/__init__.py +++ b/clx/__init__.py @@ -1,14 +1,7 @@ # flake8: noqa: E402 -from pathlib import Path - -import simplejson as json from dotenv import load_dotenv -config_path = Path.home() / ".cache" / "clx" / "config.json" -if config_path.exists(): - config = json.loads(config_path.read_text()) - if config.get("autoload-env"): - load_dotenv(override=False) +load_dotenv() from .utils import ( S3, diff --git a/clx/app/README.md b/clx/app/README.md deleted file mode 100644 index a2a1260..0000000 --- a/clx/app/README.md +++ /dev/null @@ -1,29 +0,0 @@ -# Docket Viewer Application - -## Running the server - -You'll need to use Postgres with pgvector installed. 
If you want to use Docker, you can use the following command: - -```bash -docker run -d \ - --name pgvector \ - --env-file .env \ - -p 5432:5432 \ - -v pgdata:/var/lib/postgresql/data \ - pgvector/pgvector:pg16 -``` - -You can run Django management commands with `clx manage`. - -For example, to run the development server: - -```bash -clx manage runserver -``` - -If you are starting from scratch, you can initialize the database with: - -```bash -clx manage makemigrations app -clx manage migrate -``` diff --git a/clx/app/agent.py b/clx/app/agent.py new file mode 100644 index 0000000..9acf70c --- /dev/null +++ b/clx/app/agent.py @@ -0,0 +1,260 @@ +import logging + +import litellm +from django.db.models import Sum +from shortuuid import uuid + +from clx.app.models import Message +from clx.app.tools import ( + AddTrainingExamples, + Annotate, + AskUser, + ClearToolHistory, + CompactMemory, + Search, + UpdateLabelInstructions, + UpdateProjectInstructions, +) +from clx.llm.agent import Agent, message_tokens + +logger = logging.getLogger("clx.autopilot") + +SYSTEM_PROMPT_TEMPLATE = """\ +You are an assistant for the project "{project_name}". +You are working on the label "{label_name}". + +You have access to tools for searching project documents and updating +instructions. Use them when the user asks you to find data, refine +instructions, or configure the project. + +{project_instructions_block}\ +{label_instructions_block}\ +""" + +PROJECT_INSTRUCTIONS_BLOCK = """\ +## Project Instructions +{project_instructions} + +""" + +LABEL_INSTRUCTIONS_BLOCK = """\ +## Label Instructions +{label_instructions} + +""" + + +class CLXAgent(Agent): + """A thread-backed agent that persists messages to the DB.""" + + default_tools = [ + Search, + AddTrainingExamples, + Annotate, + CompactMemory, + ClearToolHistory, + UpdateLabelInstructions, + UpdateProjectInstructions, + AskUser, + ] + + _internal_fields = {"name", "args", "hidden"} + + def __init__(self, thread, **kwargs): + self.thread = thread + label = thread.label + project = label.project + + # Build system prompt dynamically (never saved to DB). + project_instructions = project.instructions.strip() + label_instructions = label.instructions.strip() + + project_block = ( + PROJECT_INSTRUCTIONS_BLOCK.format( + project_instructions=project_instructions + ) + if project_instructions + else "" + ) + label_block = ( + LABEL_INSTRUCTIONS_BLOCK.format( + label_instructions=label_instructions + ) + if label_instructions + else "" + ) + + system_prompt = SYSTEM_PROMPT_TEMPLATE.format( + project_name=project.name, + label_name=label.name, + project_instructions_block=project_block, + label_instructions_block=label_block, + ) + + # Find the last compact message (if any) and load from there. 
+ compact_msg = ( + thread.messages.filter(is_compact=True) + .order_by("-created_at") + .first() + ) + + if compact_msg: + summary = compact_msg.data.get("content", "") + db_rows = list( + thread.messages.filter(created_at__gt=compact_msg.created_at) + .order_by("created_at") + .values_list("data", "hidden") + ) + db_messages = [ + {**data, "hidden": True} if hidden else data + for data, hidden in db_rows + ] + messages = [ + {"role": "system", "content": system_prompt}, + { + "role": "user", + "content": f"[Prior conversation summary]\n{summary}", + }, + ] + db_messages + # +2 for system prompt and synthetic summary message + self._persisted_count = len(db_messages) + 2 + else: + db_rows = list( + thread.messages.order_by("created_at").values_list( + "data", "hidden" + ) + ) + db_messages = [ + {**data, "hidden": True} if hidden else data + for data, hidden in db_rows + ] + messages = [ + {"role": "system", "content": system_prompt} + ] + db_messages + self._persisted_count = len(db_messages) + 1 + + super().__init__( + model=thread.model, + messages=messages, + state=thread.state or {}, + **kwargs, + ) + + @property + def sanitized_messages(self): + """Strip internal fields and summarize hidden messages.""" + result = [] + for message in self.messages: + cleaned = { + k: v + for k, v in message.items() + if k not in self._internal_fields + } + if message.get("hidden"): + if message.get("tool_calls"): + # Keep tool names but strip arguments. + cleaned["tool_calls"] = [ + { + **tc, + "function": { + **tc["function"], + "arguments": "{}", + }, + } + for tc in cleaned["tool_calls"] + ] + cleaned.pop("content", None) + elif message.get("role") == "tool": + cleaned["content"] = "[Removed to preserve context]" + else: + continue + result.append(cleaned) + return result + + def active_token_count(self): + """Token count from last compact point onward, excluding hidden.""" + compact_msg = ( + self.thread.messages.filter(is_compact=True) + .order_by("-created_at") + .values_list("created_at", flat=True) + .first() + ) + qs = self.thread.messages.filter(hidden=False) + if compact_msg: + qs = qs.filter(created_at__gte=compact_msg) + return qs.aggregate(total=Sum("num_tokens"))["total"] or 0 + + COMPACT_THRESHOLD = 25_000 + COMPACT_MSG = ( + "Compact your memory now. Write a detailed summary of the full " + "conversation so far. Make sure to keep track of your current " + "task instructions and progress in your compaction summary." + ) + + def compact_if_needed(self): + """Compact conversation if token count exceeds threshold.""" + if self.active_token_count() > self.COMPACT_THRESHOLD: + logger.info("Token count exceeds 25k, compacting...") + self.run(self.COMPACT_MSG) + + def autopilot_run(self, message): + """Run agent in autopilot mode. + + Returns 'completed' or 'awaiting_input'. + """ + # Run steps until CompleteTask is called or the turn ends. 
+ for _ in range(self.max_steps): + response = self.step(message, call_tools=True) + message = None + + if response.get("tool_calls"): + tool_names = { + tc["function"]["name"] for tc in response["tool_calls"] + } + if "CompleteTask" in tool_names: + return "completed" + else: + return "awaiting_input" + + return "awaiting_input" + + def on_step(self, response_message): + """Save any new messages to the database.""" + new_messages = self.messages[self._persisted_count :] + if not new_messages: + return + objects = [ + Message( + id=uuid(), + thread=self.thread, + data=msg, + num_tokens=message_tokens(msg), + is_compact=( + msg.get("role") == "tool" + and msg.get("name") == "CompactMemory" + ), + hidden=msg.get("hidden", False), + ) + for msg in new_messages + ] + Message.objects.bulk_create(objects) + self._persisted_count = len(self.messages) + + # Accumulate cost on the thread. + if self.r and self.r.usage: + try: + self.thread.total_cost += litellm.completion_cost( + completion_response=self.r + ) + except Exception: + pass + + # Persist agent state and cost back to the thread. + self.thread.state = self.state + self.thread.save( + update_fields=[ + "state", + "total_cost", + "updated_at", + ] + ) diff --git a/clx/app/apps.py b/clx/app/apps.py index 85520bb..8d6969a 100644 --- a/clx/app/apps.py +++ b/clx/app/apps.py @@ -1,11 +1,5 @@ from django.apps import AppConfig -from django.db.models.signals import post_migrate class AppConfig(AppConfig): name = "clx.app" - - def ready(self): - from .search_utils import init_search_models - - post_migrate.connect(init_search_models) diff --git a/clx/app/custom_heuristics.py b/clx/app/custom_heuristics.py deleted file mode 100644 index a89e329..0000000 --- a/clx/app/custom_heuristics.py +++ /dev/null @@ -1,23 +0,0 @@ -custom_heuristics = {} - - -def custom_heuristic(project_id, label_name): - def decorator(f): - custom_heuristics[f.__name__] = { - "project_id": project_id, - "label_name": label_name, - "apply_fn": f, - } - return f - - return decorator - - -def within_first(text, term, n): - first_n = " ".join(text.split()[:n]) - return term in first_n - - -@custom_heuristic("docket-entry", "Motion") -def first_3_motion(text, **kwargs): - return within_first(text.lower(), "motion", 3) diff --git a/clx/app/migrations/0000_custom.py b/clx/app/migrations/0000_custom.py deleted file mode 100644 index 2815d95..0000000 --- a/clx/app/migrations/0000_custom.py +++ /dev/null @@ -1,15 +0,0 @@ -from django.conf import settings -from django.contrib.postgres.operations import TrigramExtension -from django.db import migrations -from pgvector.django import VectorExtension - - -class Migration(migrations.Migration): - - dependencies = [] - - operations = [ - TrigramExtension(), - VectorExtension(), - migrations.RunSQL(f"ALTER DATABASE {settings.DATABASES['default']['NAME']} SET hnsw.ef_search = 200;"), - ] diff --git a/clx/app/migrations/0001_initial.py b/clx/app/migrations/0001_initial.py index d1ac223..6d16499 100644 --- a/clx/app/migrations/0001_initial.py +++ b/clx/app/migrations/0001_initial.py @@ -1,10 +1,9 @@ -# Generated by Django 5.2.7 on 2025-12-11 15:01 +# Generated by Django 5.2.7 on 2026-04-06 19:55 -import django.contrib.postgres.fields import django.contrib.postgres.indexes import django.db.models.deletion -import pgvector.django.halfvec -import pgvector.django.indexes +import django.utils.timezone +import django_shortuuid.fields from django.db import migrations, models @@ -13,244 +12,37 @@ class Migration(migrations.Migration): initial = True 
dependencies = [ - ('app', '0000_custom'), ] operations = [ - migrations.CreateModel( - name='DocketEntry', - fields=[ - ('created_at', models.DateTimeField(auto_now_add=True)), - ('updated_at', models.DateTimeField(auto_now=True)), - ('text', models.TextField()), - ('text_prefix', models.CharField(max_length=50)), - ('text_hash', models.CharField(max_length=255)), - ('shuffle_sort', models.IntegerField()), - ('embedding', pgvector.django.halfvec.HalfVectorField(dimensions=96)), - ('id', models.BigIntegerField(primary_key=True, serialize=False)), - ('recap_id', models.BigIntegerField(unique=True)), - ('docket_id', models.BigIntegerField()), - ('entry_number', models.BigIntegerField(blank=True, null=True)), - ('date_filed', models.DateField(blank=True, null=True)), - ], - options={ - 'db_table': 'project_docket-entry_doc', - }, - ), - migrations.CreateModel( - name='DocketEntryShort', - fields=[ - ('created_at', models.DateTimeField(auto_now_add=True)), - ('updated_at', models.DateTimeField(auto_now=True)), - ('id', models.BigIntegerField(primary_key=True, serialize=False)), - ('text_prefix', models.CharField(max_length=50)), - ('text_hash', models.CharField(max_length=255)), - ('shuffle_sort', models.IntegerField()), - ('embedding', pgvector.django.halfvec.HalfVectorField(dimensions=96)), - ('text', models.TextField(unique=True)), - ('text_type', models.CharField(choices=[('short_description', 'Short Description'), ('attachment', 'Attachment')], max_length=255)), - ('count', models.IntegerField(default=0)), - ], - options={ - 'db_table': 'project_docket-entry-short_doc', - }, - ), - migrations.CreateModel( - name='Label', - fields=[ - ('id', models.BigAutoField(primary_key=True, serialize=False)), - ('created_at', models.DateTimeField(auto_now_add=True)), - ('updated_at', models.DateTimeField(auto_now=True)), - ('name', models.CharField(max_length=255)), - ('num_excluded', models.IntegerField(default=0)), - ('num_neutral', models.IntegerField(default=0)), - ('num_likely', models.IntegerField(default=0)), - ('instructions', models.TextField(blank=True, null=True)), - ('inference_model', models.CharField(choices=[('GPT-5 Mini', 'openai/gpt-5-mini'), ('GPT-5', 'openai/gpt-5'), ('Gemini 2.5 Flash Lite', 'gemini/gemini-2.5-flash-lite'), ('Gemini 2.5 Flash', 'gemini/gemini-2.5-flash'), ('Gemini 2.5 Pro', 'gemini/gemini-2.5-pro'), ('Qwen 235B-A22B', 'bedrock/qwen.qwen3-235b-a22b-2507-v1:0'), ('Claude Sonnet 4.5', 'bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0')], default='openai/gpt-5-mini', max_length=255)), - ('teacher_model', models.CharField(choices=[('GPT-5 Mini', 'openai/gpt-5-mini'), ('GPT-5', 'openai/gpt-5'), ('Gemini 2.5 Flash Lite', 'gemini/gemini-2.5-flash-lite'), ('Gemini 2.5 Flash', 'gemini/gemini-2.5-flash'), ('Gemini 2.5 Pro', 'gemini/gemini-2.5-pro'), ('Qwen 235B-A22B', 'bedrock/qwen.qwen3-235b-a22b-2507-v1:0'), ('Claude Sonnet 4.5', 'bedrock/us.anthropic.claude-sonnet-4-5-20250929-v1:0')], default='openai/gpt-5', max_length=255)), - ('predictor_data', models.JSONField(blank=True, null=True)), - ('predictor_updated_at', models.DateTimeField(blank=True, null=True)), - ('trainset_examples_per_heuristic_bucket', models.IntegerField(default=1000)), - ('trainset_num_excluded', models.IntegerField(default=1000)), - ('trainset_num_neutral', models.IntegerField(default=1000)), - ('trainset_num_likely', models.IntegerField(default=1000)), - ('trainset_updated_at', models.DateTimeField(blank=True, null=True)), - ('trainset_predictions_updated_at', models.DateTimeField(blank=True, 
null=True)), - ('trainset_num_positive_preds', models.IntegerField(default=0)), - ('trainset_num_negative_preds', models.IntegerField(default=0)), - ], - ), - migrations.CreateModel( - name='LabelDecision', - fields=[ - ('id', models.BigAutoField(primary_key=True, serialize=False)), - ('created_at', models.DateTimeField(auto_now_add=True)), - ('updated_at', models.DateTimeField(auto_now=True)), - ('text_hash', models.CharField(max_length=255)), - ('text', models.TextField(blank=True, null=True)), - ('value', models.BooleanField()), - ('reason', models.TextField()), - ], - ), - migrations.CreateModel( - name='LabelHeuristic', - fields=[ - ('id', models.BigAutoField(primary_key=True, serialize=False)), - ('created_at', models.DateTimeField(auto_now_add=True)), - ('updated_at', models.DateTimeField(auto_now=True)), - ('querystring', models.TextField(blank=True, null=True)), - ('custom', models.CharField(blank=True, max_length=255, null=True)), - ('applied_at', models.DateTimeField(blank=True, null=True)), - ('is_minimal', models.BooleanField(default=False)), - ('is_likely', models.BooleanField(default=False)), - ('num_examples', models.IntegerField(default=0)), - ], - options={ - 'abstract': False, - }, - ), - migrations.CreateModel( - name='LabelTag', - fields=[ - ('id', models.BigAutoField(primary_key=True, serialize=False)), - ('created_at', models.DateTimeField(auto_now_add=True)), - ('updated_at', models.DateTimeField(auto_now=True)), - ('name', models.CharField(max_length=255)), - ('slug', models.CharField(max_length=255)), - ], - ), - migrations.CreateModel( - name='LabelTrainsetExample', - fields=[ - ('id', models.BigAutoField(primary_key=True, serialize=False)), - ('created_at', models.DateTimeField(auto_now_add=True)), - ('updated_at', models.DateTimeField(auto_now=True)), - ('text_hash', models.CharField(max_length=255)), - ('text', models.TextField(blank=True, null=True)), - ('split', models.CharField(choices=[('train', 'Train'), ('eval', 'Eval')], max_length=10)), - ('pred', models.BooleanField(blank=True, null=True)), - ('reason', models.TextField(blank=True, null=True)), - ], - ), migrations.CreateModel( name='Project', fields=[ + ('id', django_shortuuid.fields.ShortUUIDField(alphabet=None, blank=True, collision_check=True, editable=False, length=22, max_length=22, max_retries=10, prefix='', primary_key=True, serialize=False, unique=True)), ('created_at', models.DateTimeField(auto_now_add=True)), - ('updated_at', models.DateTimeField(auto_now=True)), - ('id', models.CharField(max_length=255, primary_key=True, serialize=False)), + ('updated_at', models.DateTimeField(default=django.utils.timezone.now)), ('name', models.CharField(max_length=255)), - ('model_name', models.CharField(max_length=255, unique=True)), - ('tags_model_name', models.CharField(blank=True, max_length=255, null=True)), - ('instructions', models.TextField(blank=True, null=True)), ], options={ 'abstract': False, }, ), migrations.CreateModel( - name='DocketEntryTags', - fields=[ - ('id', models.OneToOneField(db_column='id', on_delete=django.db.models.deletion.CASCADE, primary_key=True, related_name='example_tags', serialize=False, to='app.docketentry')), - ('tags', django.contrib.postgres.fields.ArrayField(base_field=models.BigIntegerField(), blank=True, default=list, size=None)), - ], - options={ - 'db_table': 'project_docket-entry_tags', - }, - ), - migrations.AddIndex( - model_name='docketentry', - index=models.Index(fields=['shuffle_sort', 'id'], name='docket-entry_s_idx'), - ), - migrations.AddIndex( - 
model_name='docketentry', - index=models.Index(fields=['text_prefix'], name='docket-entry_pr_idx', opclasses=['text_pattern_ops']), - ), - migrations.AddIndex( - model_name='docketentry', - index=django.contrib.postgres.indexes.GinIndex(fields=['text'], name='docket-entry_trg_idx', opclasses=['gin_trgm_ops']), - ), - migrations.AddIndex( - model_name='docketentry', - index=pgvector.django.indexes.HnswIndex(ef_construction=64, fields=['embedding'], m=16, name='docket-entry_hnsw_idx', opclasses=['halfvec_cosine_ops']), - ), - migrations.CreateModel( - name='DocketEntryShortTags', + name='Document', fields=[ - ('id', models.OneToOneField(db_column='id', on_delete=django.db.models.deletion.CASCADE, primary_key=True, related_name='example_tags', serialize=False, to='app.docketentryshort')), - ('tags', django.contrib.postgres.fields.ArrayField(base_field=models.BigIntegerField(), blank=True, default=list, size=None)), + ('id', django_shortuuid.fields.ShortUUIDField(alphabet=None, blank=True, collision_check=True, editable=False, length=22, max_length=22, max_retries=10, prefix='', primary_key=True, serialize=False, unique=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(default=django.utils.timezone.now)), + ('text', models.TextField()), + ('text_prefix', models.CharField(max_length=50)), + ('meta', models.JSONField(blank=True, default=dict, null=True)), + ('shuffle_key', models.IntegerField()), + ('text_hash', models.CharField(max_length=64)), + ('project', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='documents', to='app.project')), ], options={ - 'db_table': 'project_docket-entry-short_tags', + 'indexes': [models.Index(fields=['shuffle_key', 'id'], name='shuffle_key_idx'), models.Index(fields=['text_prefix'], name='text_prefix_idx', opclasses=['text_pattern_ops']), django.contrib.postgres.indexes.GinIndex(fields=['text'], name='text_trgm_idx', opclasses=['gin_trgm_ops'])], + 'constraints': [models.UniqueConstraint(fields=('project', 'text_hash'), name='document_project_text_hash_uniq')], }, ), - migrations.AddIndex( - model_name='docketentryshort', - index=models.Index(fields=['shuffle_sort', 'id'], name='docket-entry-short_s_idx'), - ), - migrations.AddIndex( - model_name='docketentryshort', - index=models.Index(fields=['text_prefix'], name='docket-entry-short_pr_idx', opclasses=['text_pattern_ops']), - ), - migrations.AddIndex( - model_name='docketentryshort', - index=django.contrib.postgres.indexes.GinIndex(fields=['text'], name='docket-entry-short_trg_idx', opclasses=['gin_trgm_ops']), - ), - migrations.AddIndex( - model_name='docketentryshort', - index=pgvector.django.indexes.HnswIndex(ef_construction=64, fields=['embedding'], m=16, name='docket-entry-short_hnsw_idx', opclasses=['halfvec_cosine_ops']), - ), - migrations.AddField( - model_name='labeldecision', - name='label', - field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='decisions', to='app.label'), - ), - migrations.AddField( - model_name='labelheuristic', - name='label', - field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='heuristics', to='app.label'), - ), - migrations.AddField( - model_name='labeltag', - name='heuristic', - field=models.OneToOneField(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tag', to='app.labelheuristic'), - ), - migrations.AddField( - model_name='labeltag', - name='label', - 
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='tags', to='app.label'), - ), - migrations.AddField( - model_name='labeltrainsetexample', - name='label', - field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='trainset_examples', to='app.label'), - ), - migrations.AddField( - model_name='label', - name='project', - field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, to='app.project'), - ), - migrations.AddIndex( - model_name='docketentrytags', - index=django.contrib.postgres.indexes.GinIndex(fields=['tags'], name='docket-entry_t_gin'), - ), - migrations.AddIndex( - model_name='docketentryshorttags', - index=django.contrib.postgres.indexes.GinIndex(fields=['tags'], name='docket-entry-short_t_gin'), - ), - migrations.AlterUniqueTogether( - name='labeldecision', - unique_together={('label', 'text_hash')}, - ), - migrations.AlterUniqueTogether( - name='labeltag', - unique_together={('name', 'label')}, - ), - migrations.AlterUniqueTogether( - name='labeltrainsetexample', - unique_together={('label', 'text_hash')}, - ), - migrations.AlterUniqueTogether( - name='label', - unique_together={('project', 'name')}, - ), ] diff --git a/clx/app/migrations/0002_label_project_active_label_and_more.py b/clx/app/migrations/0002_label_project_active_label_and_more.py new file mode 100644 index 0000000..2c6f6da --- /dev/null +++ b/clx/app/migrations/0002_label_project_active_label_and_more.py @@ -0,0 +1,35 @@ +# Generated by Django 5.2.7 on 2026-04-06 20:55 + +import django.db.models.deletion +import django.utils.timezone +import django_shortuuid.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0001_initial'), + ] + + operations = [ + migrations.CreateModel( + name='Label', + fields=[ + ('id', django_shortuuid.fields.ShortUUIDField(alphabet=None, blank=True, collision_check=True, editable=False, length=22, max_length=22, max_retries=10, prefix='', primary_key=True, serialize=False, unique=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(default=django.utils.timezone.now)), + ('name', models.CharField(max_length=255)), + ('project', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='labels', to='app.project')), + ], + ), + migrations.AddField( + model_name='project', + name='active_label', + field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='+', to='app.label'), + ), + migrations.AddConstraint( + model_name='label', + constraint=models.UniqueConstraint(fields=('project', 'name'), name='label_project_name_uniq'), + ), + ] diff --git a/clx/app/migrations/0002_label_trainset_num_decision_neighbors_and_more.py b/clx/app/migrations/0002_label_trainset_num_decision_neighbors_and_more.py deleted file mode 100644 index 6325bde..0000000 --- a/clx/app/migrations/0002_label_trainset_num_decision_neighbors_and_more.py +++ /dev/null @@ -1,24 +0,0 @@ -# Generated by Django 5.2.7 on 2025-12-24 02:42 - -import django.db.models.deletion -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('app', '0001_initial'), - ] - - operations = [ - migrations.AddField( - model_name='label', - name='trainset_num_decision_neighbors', - field=models.IntegerField(default=50), - ), - migrations.AlterField( - model_name='label', - name='project', - 
field=models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='labels', to='app.project'), - ), - ] diff --git a/clx/app/migrations/0003_label_instructions_thread_message.py b/clx/app/migrations/0003_label_instructions_thread_message.py new file mode 100644 index 0000000..1e62469 --- /dev/null +++ b/clx/app/migrations/0003_label_instructions_thread_message.py @@ -0,0 +1,48 @@ +# Generated by Django 5.2.7 on 2026-04-06 22:03 + +import django.db.models.deletion +import django.utils.timezone +import django_shortuuid.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0002_label_project_active_label_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='label', + name='instructions', + field=models.TextField(blank=True, default=''), + ), + migrations.CreateModel( + name='Thread', + fields=[ + ('id', django_shortuuid.fields.ShortUUIDField(alphabet=None, blank=True, collision_check=True, editable=False, length=22, max_length=22, max_retries=10, prefix='', primary_key=True, serialize=False, unique=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(default=django.utils.timezone.now)), + ('model', models.CharField(max_length=255)), + ('label', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='threads', to='app.label')), + ], + options={ + 'abstract': False, + }, + ), + migrations.CreateModel( + name='Message', + fields=[ + ('id', django_shortuuid.fields.ShortUUIDField(alphabet=None, blank=True, collision_check=True, editable=False, length=22, max_length=22, max_retries=10, prefix='', primary_key=True, serialize=False, unique=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(default=django.utils.timezone.now)), + ('data', models.JSONField(default=dict)), + ('num_tokens', models.IntegerField(default=0)), + ('thread', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='messages', to='app.thread')), + ], + options={ + 'abstract': False, + }, + ), + ] diff --git a/clx/app/migrations/0003_labelfinetune.py b/clx/app/migrations/0003_labelfinetune.py deleted file mode 100644 index 6b996e4..0000000 --- a/clx/app/migrations/0003_labelfinetune.py +++ /dev/null @@ -1,28 +0,0 @@ -# Generated by Django 5.2.7 on 2025-12-29 12:59 - -import django.db.models.deletion -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('app', '0002_label_trainset_num_decision_neighbors_and_more'), - ] - - operations = [ - migrations.CreateModel( - name='LabelFinetune', - fields=[ - ('id', models.BigAutoField(primary_key=True, serialize=False)), - ('created_at', models.DateTimeField(auto_now_add=True)), - ('updated_at', models.DateTimeField(auto_now=True)), - ('config_name', models.CharField(max_length=255)), - ('eval_results', models.JSONField(blank=True, null=True)), - ('label', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='fintunes', to='app.label')), - ], - options={ - 'abstract': False, - }, - ), - ] diff --git a/clx/app/migrations/0004_labelfinetune_finetuned_at_and_more.py b/clx/app/migrations/0004_labelfinetune_finetuned_at_and_more.py deleted file mode 100644 index c5e4b19..0000000 --- a/clx/app/migrations/0004_labelfinetune_finetuned_at_and_more.py +++ /dev/null @@ -1,23 +0,0 @@ -# Generated by Django 5.2.7 on 2026-01-15 18:50 - -from django.db import migrations, models - 
- -class Migration(migrations.Migration): - - dependencies = [ - ('app', '0003_labelfinetune'), - ] - - operations = [ - migrations.AddField( - model_name='labelfinetune', - name='finetuned_at', - field=models.DateTimeField(blank=True, null=True), - ), - migrations.AddField( - model_name='labelfinetune', - name='predicted_at', - field=models.DateTimeField(blank=True, null=True), - ), - ] diff --git a/clx/app/migrations/0004_project_instructions_project_manual_instructions_and_more.py b/clx/app/migrations/0004_project_instructions_project_manual_instructions_and_more.py new file mode 100644 index 0000000..689cff1 --- /dev/null +++ b/clx/app/migrations/0004_project_instructions_project_manual_instructions_and_more.py @@ -0,0 +1,28 @@ +# Generated by Django 5.2.7 on 2026-04-06 22:42 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0003_label_instructions_thread_message'), + ] + + operations = [ + migrations.AddField( + model_name='project', + name='instructions', + field=models.TextField(blank=True, default=''), + ), + migrations.AddField( + model_name='project', + name='manual_instructions', + field=models.BooleanField(default=False), + ), + migrations.AlterField( + model_name='thread', + name='model', + field=models.CharField(default='gemini/gemini-3.1-pro-preview', max_length=255), + ), + ] diff --git a/clx/app/migrations/0005_remove_labeltrainsetexample_reason_and_more.py b/clx/app/migrations/0005_remove_labeltrainsetexample_reason_and_more.py deleted file mode 100644 index 73b31f7..0000000 --- a/clx/app/migrations/0005_remove_labeltrainsetexample_reason_and_more.py +++ /dev/null @@ -1,23 +0,0 @@ -# Generated by Django 5.2.7 on 2026-02-13 21:12 - -import django.db.models.deletion -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('app', '0004_labelfinetune_finetuned_at_and_more'), - ] - - operations = [ - migrations.RemoveField( - model_name='labeltrainsetexample', - name='reason', - ), - migrations.AddField( - model_name='labeltrainsetexample', - name='decision', - field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='trainset_examples', to='app.labeldecision'), - ), - ] diff --git a/clx/app/migrations/0005_remove_project_manual_instructions.py b/clx/app/migrations/0005_remove_project_manual_instructions.py new file mode 100644 index 0000000..7735cdd --- /dev/null +++ b/clx/app/migrations/0005_remove_project_manual_instructions.py @@ -0,0 +1,17 @@ +# Generated by Django 5.2.7 on 2026-04-07 15:51 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0004_project_instructions_project_manual_instructions_and_more'), + ] + + operations = [ + migrations.RemoveField( + model_name='project', + name='manual_instructions', + ), + ] diff --git a/clx/app/migrations/0006_remove_labeltrainsetexample_decision_and_more.py b/clx/app/migrations/0006_remove_labeltrainsetexample_decision_and_more.py deleted file mode 100644 index e8dad74..0000000 --- a/clx/app/migrations/0006_remove_labeltrainsetexample_decision_and_more.py +++ /dev/null @@ -1,22 +0,0 @@ -# Generated by Django 5.2.7 on 2026-02-16 22:16 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('app', '0005_remove_labeltrainsetexample_reason_and_more'), - ] - - operations = [ - migrations.RemoveField( - model_name='labeltrainsetexample', - 
name='decision', - ), - migrations.AddField( - model_name='labeltrainsetexample', - name='reason', - field=models.TextField(blank=True, null=True), - ), - ] diff --git a/clx/app/migrations/0006_thread_state.py b/clx/app/migrations/0006_thread_state.py new file mode 100644 index 0000000..31adf47 --- /dev/null +++ b/clx/app/migrations/0006_thread_state.py @@ -0,0 +1,18 @@ +# Generated by Django 5.2.7 on 2026-04-07 16:05 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0005_remove_project_manual_instructions'), + ] + + operations = [ + migrations.AddField( + model_name='thread', + name='state', + field=models.JSONField(blank=True, default=dict), + ), + ] diff --git a/clx/app/migrations/0007_labeldocument.py b/clx/app/migrations/0007_labeldocument.py new file mode 100644 index 0000000..41c14e2 --- /dev/null +++ b/clx/app/migrations/0007_labeldocument.py @@ -0,0 +1,29 @@ +# Generated by Django 5.2.7 on 2026-04-07 16:37 + +import django.db.models.deletion +import django.utils.timezone +import django_shortuuid.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0006_thread_state'), + ] + + operations = [ + migrations.CreateModel( + name='LabelDocument', + fields=[ + ('id', django_shortuuid.fields.ShortUUIDField(alphabet=None, blank=True, collision_check=True, editable=False, length=22, max_length=22, max_retries=10, prefix='', primary_key=True, serialize=False, unique=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(default=django.utils.timezone.now)), + ('document', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='label_documents', to='app.document')), + ('label', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='label_documents', to='app.label')), + ], + options={ + 'constraints': [models.UniqueConstraint(fields=('label', 'document'), name='labeldocument_label_document_uniq')], + }, + ), + ] diff --git a/clx/app/migrations/0007_remove_label_inference_model_and_more.py b/clx/app/migrations/0007_remove_label_inference_model_and_more.py deleted file mode 100644 index e57578e..0000000 --- a/clx/app/migrations/0007_remove_label_inference_model_and_more.py +++ /dev/null @@ -1,48 +0,0 @@ -# Generated by Django 5.2.7 on 2026-02-17 15:15 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('app', '0006_remove_labeltrainsetexample_decision_and_more'), - ] - - operations = [ - migrations.RemoveField( - model_name='label', - name='inference_model', - ), - migrations.RemoveField( - model_name='label', - name='predictor_data', - ), - migrations.RemoveField( - model_name='label', - name='predictor_updated_at', - ), - migrations.RemoveField( - model_name='label', - name='teacher_model', - ), - migrations.RemoveField( - model_name='label', - name='trainset_examples_per_heuristic_bucket', - ), - migrations.AlterField( - model_name='label', - name='trainset_num_excluded', - field=models.IntegerField(default=50), - ), - migrations.AlterField( - model_name='label', - name='trainset_num_likely', - field=models.IntegerField(default=50), - ), - migrations.AlterField( - model_name='label', - name='trainset_num_neutral', - field=models.IntegerField(default=50), - ), - ] diff --git a/clx/app/migrations/0008_alter_label_trainset_num_decision_neighbors_and_more.py 
b/clx/app/migrations/0008_alter_label_trainset_num_decision_neighbors_and_more.py deleted file mode 100644 index eecf460..0000000 --- a/clx/app/migrations/0008_alter_label_trainset_num_decision_neighbors_and_more.py +++ /dev/null @@ -1,33 +0,0 @@ -# Generated by Django 5.2.7 on 2026-02-18 14:03 - -import django.db.models.deletion -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('app', '0007_remove_label_inference_model_and_more'), - ] - - operations = [ - migrations.AlterField( - model_name='label', - name='trainset_num_decision_neighbors', - field=models.IntegerField(default=20), - ), - migrations.CreateModel( - name='LabelQuerystring', - fields=[ - ('id', models.BigAutoField(primary_key=True, serialize=False)), - ('created_at', models.DateTimeField(auto_now_add=True)), - ('updated_at', models.DateTimeField(auto_now=True)), - ('querystring', models.TextField()), - ('num_examples', models.IntegerField(default=30)), - ('label', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='querystrings', to='app.label')), - ], - options={ - 'abstract': False, - }, - ), - ] diff --git a/clx/app/migrations/0008_alter_thread_model.py b/clx/app/migrations/0008_alter_thread_model.py new file mode 100644 index 0000000..55061ef --- /dev/null +++ b/clx/app/migrations/0008_alter_thread_model.py @@ -0,0 +1,18 @@ +# Generated by Django 5.2.7 on 2026-04-07 16:51 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0007_labeldocument'), + ] + + operations = [ + migrations.AlterField( + model_name='thread', + name='model', + field=models.CharField(default='openai/gpt-5.4', max_length=255), + ), + ] diff --git a/clx/app/migrations/0009_alter_labelquerystring_num_examples_and_more.py b/clx/app/migrations/0009_alter_labelquerystring_num_examples_and_more.py deleted file mode 100644 index ac46d73..0000000 --- a/clx/app/migrations/0009_alter_labelquerystring_num_examples_and_more.py +++ /dev/null @@ -1,22 +0,0 @@ -# Generated by Django 5.2.7 on 2026-02-18 14:36 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('app', '0008_alter_label_trainset_num_decision_neighbors_and_more'), - ] - - operations = [ - migrations.AlterField( - model_name='labelquerystring', - name='num_examples', - field=models.IntegerField(default=50), - ), - migrations.AlterUniqueTogether( - name='labelquerystring', - unique_together={('label', 'querystring')}, - ), - ] diff --git a/clx/app/migrations/0009_classificationannotation.py b/clx/app/migrations/0009_classificationannotation.py new file mode 100644 index 0000000..2bbf346 --- /dev/null +++ b/clx/app/migrations/0009_classificationannotation.py @@ -0,0 +1,30 @@ +# Generated by Django 5.2.7 on 2026-04-07 18:09 + +import django.db.models.deletion +import django.utils.timezone +import django_shortuuid.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0008_alter_thread_model'), + ] + + operations = [ + migrations.CreateModel( + name='ClassificationAnnotation', + fields=[ + ('id', django_shortuuid.fields.ShortUUIDField(alphabet=None, blank=True, collision_check=True, editable=False, length=22, max_length=22, max_retries=10, prefix='', primary_key=True, serialize=False, unique=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(default=django.utils.timezone.now)), 
+ ('value', models.CharField(choices=[('yes', 'Yes'), ('no', 'No'), ('skip', 'Skip')], max_length=4)), + ('source', models.CharField(max_length=255)), + ('label_document', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='annotations', to='app.labeldocument')), + ], + options={ + 'constraints': [models.UniqueConstraint(fields=('label_document', 'source'), name='annotation_labeldoc_source_uniq')], + }, + ), + ] diff --git a/clx/app/migrations/0010_labeldecision_added_to_sample_and_more.py b/clx/app/migrations/0010_labeldecision_added_to_sample_and_more.py deleted file mode 100644 index 09b1f56..0000000 --- a/clx/app/migrations/0010_labeldecision_added_to_sample_and_more.py +++ /dev/null @@ -1,23 +0,0 @@ -# Generated by Django 5.2.7 on 2026-02-18 15:11 - -from django.db import migrations, models - - -class Migration(migrations.Migration): - - dependencies = [ - ('app', '0009_alter_labelquerystring_num_examples_and_more'), - ] - - operations = [ - migrations.AddField( - model_name='labeldecision', - name='added_to_sample', - field=models.BooleanField(default=False), - ), - migrations.AddField( - model_name='labelquerystring', - name='added_to_sample', - field=models.BooleanField(default=False), - ), - ] diff --git a/clx/app/migrations/0010_thread_total_cost_thread_total_tokens.py b/clx/app/migrations/0010_thread_total_cost_thread_total_tokens.py new file mode 100644 index 0000000..164eb0b --- /dev/null +++ b/clx/app/migrations/0010_thread_total_cost_thread_total_tokens.py @@ -0,0 +1,23 @@ +# Generated by Django 5.2.7 on 2026-04-07 18:53 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0009_classificationannotation'), + ] + + operations = [ + migrations.AddField( + model_name='thread', + name='total_cost', + field=models.FloatField(default=0.0), + ), + migrations.AddField( + model_name='thread', + name='total_tokens', + field=models.IntegerField(default=0), + ), + ] diff --git a/clx/app/migrations/0011_remove_thread_total_tokens.py b/clx/app/migrations/0011_remove_thread_total_tokens.py new file mode 100644 index 0000000..be624bb --- /dev/null +++ b/clx/app/migrations/0011_remove_thread_total_tokens.py @@ -0,0 +1,17 @@ +# Generated by Django 5.2.7 on 2026-04-07 23:01 + +from django.db import migrations + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0010_thread_total_cost_thread_total_tokens'), + ] + + operations = [ + migrations.RemoveField( + model_name='thread', + name='total_tokens', + ), + ] diff --git a/clx/app/migrations/0012_message_is_compact.py b/clx/app/migrations/0012_message_is_compact.py new file mode 100644 index 0000000..75f339c --- /dev/null +++ b/clx/app/migrations/0012_message_is_compact.py @@ -0,0 +1,18 @@ +# Generated by Django 5.2.7 on 2026-04-07 23:37 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0011_remove_thread_total_tokens'), + ] + + operations = [ + migrations.AddField( + model_name='message', + name='is_compact', + field=models.BooleanField(default=False), + ), + ] diff --git a/clx/app/migrations/0013_prompt.py b/clx/app/migrations/0013_prompt.py new file mode 100644 index 0000000..abbdb3a --- /dev/null +++ b/clx/app/migrations/0013_prompt.py @@ -0,0 +1,31 @@ +# Generated by Django 5.2.7 on 2026-04-08 12:33 + +import django.db.models.deletion +import django.utils.timezone +import django_shortuuid.fields +from django.db import migrations, models + + +class 
Migration(migrations.Migration): + + dependencies = [ + ('app', '0012_message_is_compact'), + ] + + operations = [ + migrations.CreateModel( + name='Prompt', + fields=[ + ('id', django_shortuuid.fields.ShortUUIDField(alphabet=None, blank=True, collision_check=True, editable=False, length=22, max_length=22, max_retries=10, prefix='', primary_key=True, serialize=False, unique=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(default=django.utils.timezone.now)), + ('prompt_id', models.CharField(max_length=255)), + ('name', models.CharField(max_length=255)), + ('content', models.TextField(blank=True, default='')), + ('project', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='prompts', to='app.project')), + ], + options={ + 'constraints': [models.UniqueConstraint(fields=('project', 'prompt_id'), name='prompt_project_promptid_uniq')], + }, + ), + ] diff --git a/clx/app/migrations/0014_prompt_built_in.py b/clx/app/migrations/0014_prompt_built_in.py new file mode 100644 index 0000000..fb8fb56 --- /dev/null +++ b/clx/app/migrations/0014_prompt_built_in.py @@ -0,0 +1,18 @@ +# Generated by Django 5.2.7 on 2026-04-08 13:37 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0013_prompt'), + ] + + operations = [ + migrations.AddField( + model_name='prompt', + name='built_in', + field=models.BooleanField(default=False), + ), + ] diff --git a/clx/app/migrations/0015_task.py b/clx/app/migrations/0015_task.py new file mode 100644 index 0000000..3a946e3 --- /dev/null +++ b/clx/app/migrations/0015_task.py @@ -0,0 +1,31 @@ +# Generated by Django 5.2.7 on 2026-04-08 15:54 + +import django.db.models.deletion +import django.utils.timezone +import django_shortuuid.fields +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0014_prompt_built_in'), + ] + + operations = [ + migrations.CreateModel( + name='Task', + fields=[ + ('id', django_shortuuid.fields.ShortUUIDField(alphabet=None, blank=True, collision_check=True, editable=False, length=22, max_length=22, max_retries=10, prefix='', primary_key=True, serialize=False, unique=True)), + ('created_at', models.DateTimeField(auto_now_add=True)), + ('updated_at', models.DateTimeField(default=django.utils.timezone.now)), + ('prompt_id', models.CharField(max_length=255)), + ('status', models.CharField(choices=[('pending', 'Pending'), ('in_progress', 'In Progress')], default='pending', max_length=20)), + ('label', models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.CASCADE, related_name='tasks', to='app.label')), + ('project', models.ForeignKey(on_delete=django.db.models.deletion.CASCADE, related_name='tasks', to='app.project')), + ], + options={ + 'constraints': [models.UniqueConstraint(fields=('project', 'prompt_id', 'label'), name='task_project_prompt_label_uniq')], + }, + ), + ] diff --git a/clx/app/migrations/0016_label_autopilot_thread.py b/clx/app/migrations/0016_label_autopilot_thread.py new file mode 100644 index 0000000..b38b264 --- /dev/null +++ b/clx/app/migrations/0016_label_autopilot_thread.py @@ -0,0 +1,19 @@ +# Generated by Django 5.2.7 on 2026-04-08 15:59 + +import django.db.models.deletion +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0015_task'), + ] + + operations = [ + migrations.AddField( + model_name='label', + name='autopilot_thread', + 
field=models.ForeignKey(blank=True, null=True, on_delete=django.db.models.deletion.SET_NULL, related_name='+', to='app.thread'), + ), + ] diff --git a/clx/app/migrations/0017_project_autopilot_enabled.py b/clx/app/migrations/0017_project_autopilot_enabled.py new file mode 100644 index 0000000..0badb2c --- /dev/null +++ b/clx/app/migrations/0017_project_autopilot_enabled.py @@ -0,0 +1,18 @@ +# Generated by Django 5.2.7 on 2026-04-08 16:02 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0016_label_autopilot_thread'), + ] + + operations = [ + migrations.AddField( + model_name='project', + name='autopilot_enabled', + field=models.BooleanField(default=False), + ), + ] diff --git a/clx/app/migrations/0018_thread_autopilot_locked_alter_task_status.py b/clx/app/migrations/0018_thread_autopilot_locked_alter_task_status.py new file mode 100644 index 0000000..c62c046 --- /dev/null +++ b/clx/app/migrations/0018_thread_autopilot_locked_alter_task_status.py @@ -0,0 +1,23 @@ +# Generated by Django 5.2.7 on 2026-04-08 16:44 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0017_project_autopilot_enabled'), + ] + + operations = [ + migrations.AddField( + model_name='thread', + name='autopilot_locked', + field=models.BooleanField(default=False), + ), + migrations.AlterField( + model_name='task', + name='status', + field=models.CharField(choices=[('pending', 'Pending'), ('in_progress', 'In Progress'), ('awaiting_input', 'Awaiting Input')], default='pending', max_length=20), + ), + ] diff --git a/clx/app/migrations/0019_message_hidden.py b/clx/app/migrations/0019_message_hidden.py new file mode 100644 index 0000000..20b143f --- /dev/null +++ b/clx/app/migrations/0019_message_hidden.py @@ -0,0 +1,18 @@ +# Generated by Django 5.2.7 on 2026-04-08 22:15 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0018_thread_autopilot_locked_alter_task_status'), + ] + + operations = [ + migrations.AddField( + model_name='message', + name='hidden', + field=models.BooleanField(default=False), + ), + ] diff --git a/clx/app/migrations/0020_label_finetune_id_label_finetune_status_and_more.py b/clx/app/migrations/0020_label_finetune_id_label_finetune_status_and_more.py new file mode 100644 index 0000000..3dcb61a --- /dev/null +++ b/clx/app/migrations/0020_label_finetune_id_label_finetune_status_and_more.py @@ -0,0 +1,33 @@ +# Generated by Django 5.2.7 on 2026-04-09 17:54 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0019_message_hidden'), + ] + + operations = [ + migrations.AddField( + model_name='label', + name='finetune_id', + field=models.CharField(blank=True, default='', max_length=255), + ), + migrations.AddField( + model_name='label', + name='finetune_status', + field=models.CharField(blank=True, default='', max_length=20), + ), + migrations.AddField( + model_name='label', + name='finetune_training_args', + field=models.JSONField(blank=True, default=dict), + ), + migrations.AddField( + model_name='label', + name='finetuned_at', + field=models.DateTimeField(blank=True, null=True), + ), + ] diff --git a/clx/app/migrations/0021_label_predicted_at_labeldocument_prediction_and_more.py b/clx/app/migrations/0021_label_predicted_at_labeldocument_prediction_and_more.py new file mode 100644 index 0000000..d48cfa5 --- /dev/null +++ 
b/clx/app/migrations/0021_label_predicted_at_labeldocument_prediction_and_more.py @@ -0,0 +1,28 @@ +# Generated by Django 5.2.7 on 2026-04-09 18:31 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0020_label_finetune_id_label_finetune_status_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='label', + name='predicted_at', + field=models.DateTimeField(blank=True, null=True), + ), + migrations.AddField( + model_name='labeldocument', + name='prediction', + field=models.CharField(blank=True, choices=[('yes', 'yes'), ('no', 'no')], default='', max_length=3), + ), + migrations.AddField( + model_name='labeldocument', + name='prediction_confidence', + field=models.FloatField(blank=True, null=True), + ), + ] diff --git a/clx/app/migrations/0022_label_prediction_stats.py b/clx/app/migrations/0022_label_prediction_stats.py new file mode 100644 index 0000000..ab919d5 --- /dev/null +++ b/clx/app/migrations/0022_label_prediction_stats.py @@ -0,0 +1,18 @@ +# Generated by Django 5.2.7 on 2026-04-09 18:33 + +from django.db import migrations, models + + +class Migration(migrations.Migration): + + dependencies = [ + ('app', '0021_label_predicted_at_labeldocument_prediction_and_more'), + ] + + operations = [ + migrations.AddField( + model_name='label', + name='prediction_stats', + field=models.JSONField(blank=True, default=dict), + ), + ] diff --git a/clx/app/models.py b/clx/app/models.py index 5eefa1e..a314c6e 100644 --- a/clx/app/models.py +++ b/clx/app/models.py @@ -1,1060 +1,721 @@ -import time -from concurrent.futures import ThreadPoolExecutor, as_completed +import json +import random +from datetime import UTC, datetime +from io import StringIO +from pathlib import Path -import lmdb -import numpy as np import pandas as pd -import simplejson as json -from django.apps import apps +from django.conf import settings as django_settings +from django.contrib.postgres.indexes import GinIndex from django.db import models from django.utils import timezone +from django_shortuuid.fields import ShortUUIDField +from shortuuid import uuid from tqdm import tqdm -from clx import label2slug -from clx.llm import batch_embed, mesh_sort -from clx.llm.anno_agent import AnnoAgent -from clx.ml import pipeline, training_run from clx.settings import CLX_HOME -from clx.utils import pd_save_or_append +from clx.utils import generate_hash, pd_save_or_append -from .custom_heuristics import custom_heuristics -from .search_utils import BaseModel, SearchDocumentModel +from .search import SearchManager -class Project(BaseModel): - """Model for projects.""" - - id = models.CharField(max_length=255, primary_key=True) - name = models.CharField(max_length=255) - model_name = models.CharField(max_length=255, unique=True) - tags_model_name = models.CharField(max_length=255, null=True, blank=True) - instructions = models.TextField(null=True, blank=True) - - @property - def data_dir(self): - return CLX_HOME / "app_projects" / self.id - - @property - def cached_documents_path(self): - return self.data_dir / "docs.csv" +class Base(models.Model): + """Abstract base model for all CLX models.""" - @property - def cached_embeddings_path(self): - return self.data_dir / "embeddings.lmdb" - - def load_or_add_embeddings(self, data): - assert all(x in data.columns for x in ["text_hash", "text"]) - db = lmdb.open(str(self.cached_embeddings_path), map_size=1024**4) - with db.begin() as c: - data["embedding"] = data["text_hash"].apply( - lambda x: 
c.get(x.encode("utf-8")) - ) - data["embedding"] = data["embedding"].apply( - lambda x: json.loads(x) if x is not None else None - ) - needs_embeddings = data[data["embedding"].isna()] - data = data[data["embedding"].notna()] - needs_embeddings["embedding"] = batch_embed( - needs_embeddings["text"].tolist(), - num_workers=16, - dimensions=96, - ) - with db.begin(write=True) as c: - for row in needs_embeddings.to_dict("records"): - c.put( - row["text_hash"].encode("utf-8"), - json.dumps(row["embedding"]).encode("utf-8"), - ) - data = pd.concat([data, needs_embeddings]) - return data + id = ShortUUIDField(primary_key=True, editable=False) + created_at = models.DateTimeField(auto_now_add=True) + updated_at = models.DateTimeField(default=timezone.now) - @property - def cached_documents(self): - return pd.read_csv(self.cached_documents_path) + objects = SearchManager() - def get_search_model(self): - """Get the search model class for the project.""" - return apps.get_model("app", self.model_name) + class Meta: + abstract = True - def get_tags_model(self): - """Get the tags model class for the project.""" - return apps.get_model("app", self.tags_model_name) + def save(self, *args, **kwargs): + self.updated_at = timezone.now() + super().save(*args, **kwargs) -class Label(BaseModel): - """Model for labels.""" +class Project(Base): + """Model for projects.""" - project = models.ForeignKey( - Project, on_delete=models.CASCADE, related_name="labels" - ) name = models.CharField(max_length=255) - instructions = models.TextField(null=True, blank=True) - - # Sample counts - num_excluded = models.IntegerField(default=0) - num_neutral = models.IntegerField(default=0) - num_likely = models.IntegerField(default=0) - - # Trainset config - trainset_num_excluded = models.IntegerField(default=50) - trainset_num_neutral = models.IntegerField(default=50) - trainset_num_likely = models.IntegerField(default=50) - trainset_num_decision_neighbors = models.IntegerField(default=20) - trainset_updated_at = models.DateTimeField(null=True, blank=True) - trainset_predictions_updated_at = models.DateTimeField( - null=True, blank=True + instructions = models.TextField(blank=True, default="") + autopilot_enabled = models.BooleanField(default=False) + active_label = models.ForeignKey( + "Label", + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="+", ) - trainset_num_positive_preds = models.IntegerField(default=0) - trainset_num_negative_preds = models.IntegerField(default=0) - - @property - def data_dir(self): - return self.project.data_dir / "labels" / f"{label2slug(self.name)}" - - def excluded_query(self, queryset=None): - if queryset is None: - queryset = self.project.get_search_model().objects - tags = LabelTag.objects.filter(label=self, heuristic__is_minimal=True) - tag_ids = tags.values_list("id", flat=True) - if not tag_ids: - return queryset.none() - return queryset.tags(not_any=tag_ids) - - def neutral_query(self, queryset=None): - if queryset is None: - queryset = self.project.get_search_model().objects - minimal_tags = LabelTag.objects.filter( - label=self, heuristic__is_minimal=True - ) - minimal_tag_ids = minimal_tags.values_list("id", flat=True) - likely_tags = LabelTag.objects.filter( - label=self, heuristic__is_likely=True - ) - likely_tag_ids = likely_tags.values_list("id", flat=True) - return queryset.tags(any=minimal_tag_ids, not_any=likely_tag_ids) - - def likely_query(self, queryset=None): - if queryset is None: - queryset = self.project.get_search_model().objects - minimal_tags = 
LabelTag.objects.filter( - label=self, heuristic__is_minimal=True - ) - minimal_tag_ids = minimal_tags.values_list("id", flat=True) - likely_tags = LabelTag.objects.filter( - label=self, heuristic__is_likely=True - ) - likely_tag_ids = likely_tags.values_list("id", flat=True) - if not likely_tag_ids: - return queryset.none() - return queryset.tags(any=minimal_tag_ids).tags(any=likely_tag_ids) - - def get_minimal_fn(self): - minimal_fns = [ - x.heuristic.get_apply_fn() - for x in LabelTag.objects.filter( - label=self, heuristic__is_minimal=True - ) - ] - - def minimal_fn(text): - return any(f(text) for f in minimal_fns) - return minimal_fn + def add_docs(self, docs, **kwargs): + """Bulk-insert documents using django-postgres-copy. - def get_likely_fn(self): - likely_fns = [ - x.heuristic.get_apply_fn() - for x in LabelTag.objects.filter( - label=self, heuristic__is_likely=True - ) - ] + Args: + docs: Either a DataFrame with a 'text' column, a list of strings + (text only), or a list of dicts with 'text' and optionally + 'meta' keys. + """ + if docs is None: + return - def likely_fn(text): - return any(f(text) for f in likely_fns) - - return likely_fn - - def update_counts(self): - self.num_excluded = self.excluded_query().count() - self.num_likely = self.likely_query().count() - self.num_neutral = self.neutral_query().count() - self.save() - - def update_trainset(self): - data = self.load_trainset() - model = self.project.get_search_model() - - # Reset predictions for existing anno disagreements - needs_corrections = data[ - data["anno_value"].notna() - & data["pred"].notna() - & (data["anno_value"] != data["pred"]) - ]["text_hash"].tolist() - if len(needs_corrections): - LabelTrainsetExample.objects.filter( - label=self, text_hash__in=needs_corrections - ).update(pred=None, reason=None) - - new_ids = [] - - # Sample decision neighbors - model = self.project.get_search_model() - for decision in self.decisions.all(): - if not decision.added_to_sample: - embedding = ( - model.objects.filter(text_hash=decision.text_hash) - .first() - .embedding.to_list() - ) - decision_examples = model.objects.search( - semantic_sort=embedding, - page_size=self.trainset_num_decision_neighbors, + # Normalize input + if isinstance(docs, pd.DataFrame): + if docs.empty: + return + if "text" not in docs.columns: + raise ValueError( + "DataFrame input must include a 'text' column" ) - new_ids += [x["id"] for x in decision_examples["data"]] - decision.save(added_to_sample=True) - - # Sample on querystring samplers - for querystring in self.querystrings.all(): - if not querystring.added_to_sample: - querystring_examples = model.objects.search( - params={"querystring": querystring.querystring}, - page_size=querystring.num_examples, - sort=["shuffle_sort", "id"], - ) - new_ids += [x["id"] for x in querystring_examples["data"]] - querystring.save(added_to_sample=True) - - # Mesh sort helper - def apply_mesh_sort(queryset, n_examples): - """Select 10x the number of examples and take most diverse 10%""" - cluster_ks = [10, 10] - data = queryset.order_by("?").values("id", "embedding") - data = pd.DataFrame(data[: n_examples * 10]) - data["embedding"] = data["embedding"].apply(lambda x: x.to_list()) - data["sort"] = mesh_sort( - np.array(data["embedding"].tolist()), cluster_ks - ) - data = data.sort_values(by="sort").head(n_examples) - return data["id"].tolist() - # Sample from heuristic buckets - num_excluded = self.trainset_num_excluded - len( - data[data["bucket"] == "excluded"] - ) - num_neutral = 
self.trainset_num_neutral - len( - data[data["bucket"] == "neutral"] - ) - num_likely = self.trainset_num_likely - len( - data[data["bucket"] == "likely"] - ) + meta_columns = [col for col in docs.columns if col != "text"] + docs = [ + { + "text": row["text"], + "meta": { + key: value + for key, value in row[meta_columns].items() + if pd.notna(value) + }, + } + for _, row in docs.iterrows() + ] + elif not docs: + return + elif isinstance(docs[0], str): + docs = [{"text": t, "meta": {}} for t in docs] + else: + docs = [ + {"text": d["text"], "meta": d.get("meta", {})} for d in docs + ] - if num_excluded > 0: - new_ids += apply_mesh_sort(self.excluded_query(), num_excluded) - if num_neutral > 0: - new_ids += apply_mesh_sort(self.neutral_query(), num_neutral) - if num_likely > 0: - new_ids += apply_mesh_sort(self.likely_query(), num_likely) - - # Get new examples - cols = ["text", "text_hash"] - new_examples = pd.DataFrame( - model.objects.filter(id__in=new_ids).values(*cols), - columns=cols, - ) - new_examples = new_examples[ - ~new_examples["text_hash"].isin(data["text_hash"]) + # Build DataFrame + data = pd.DataFrame(docs) + data = data.dropna(subset=["text"]) + if len(data) == 0: + return + data["id"] = [uuid() for _ in range(len(data))] + data["text_prefix"] = data["text"].str[:50] + data["text_hash"] = data["text"].apply(generate_hash) + data["shuffle_key"] = [ + random.randint(0, 1_000_000) for _ in range(len(data)) ] - new_examples = new_examples.drop_duplicates(subset="text_hash") - new_examples = new_examples.sample(frac=1) - - # Make train/eval split - split = int(len(new_examples) * 0.8) - train_examples = new_examples.head(split) - train_examples["split"] = "train" - eval_examples = new_examples.tail(len(new_examples) - split) - eval_examples["split"] = "eval" - new_examples = pd.concat([train_examples, eval_examples]) - - new_examples = pd.concat([train_examples, eval_examples]) - - # Add to trainset - rows = new_examples.to_dict("records") - LabelTrainsetExample.objects.bulk_create( - [LabelTrainsetExample(label_id=self.id, **row) for row in rows], - batch_size=1000, - ) - self.sync_trainset_tags() - self.update_trainset_pred_counts() - self.trainset_updated_at = timezone.now() - self.save() - - def reset_trainset(self): - self.trainset_examples.all().delete() - self.decisions.all().update(added_to_sample=False) - self.querystrings.all().update(added_to_sample=False) - self.sync_trainset_tags() - self.update_trainset_pred_counts() - self.trainset_updated_at = None - self.trainset_predictions_updated_at = None - self.save() - - def load_annos(self): - project = self.project - search_model = project.get_search_model() - - pos_annos = search_model.objects.tags( - any=[self.anno_true_tag.id] - ).values("text_hash", "text") - pos_annos = pd.DataFrame(pos_annos) - pos_annos["value"] = True - - neg_annos = search_model.objects.tags( - any=[self.anno_false_tag.id] - ).values("text_hash", "text") - neg_annos = pd.DataFrame(neg_annos) - neg_annos["value"] = False - - flag_annos = search_model.objects.tags( - any=[self.anno_flag_tag.id] - ).values("text_hash", "text") - flag_annos = pd.DataFrame(flag_annos) - flag_annos["value"] = None - - annos = pd.concat([pos_annos, neg_annos, flag_annos]) - if annos.empty: - return pd.DataFrame(columns=["text_hash", "text", "value"]) - return annos - - def load_trainset(self): - trainset_hashes = list( - self.trainset_examples.values_list("text_hash", flat=True) + data["meta"] = data["meta"].apply(json.dumps) + + # Write to CSV buffer + f = StringIO() 
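+        # meta was JSON-encoded above so COPY can load it straight into the
+        # JSONField; ignore_conflicts below relies on the (project, text_hash)
+        # unique constraint to skip documents that already exist.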
+ data.to_csv(f, index=False) + f.seek(0) + + # Bulk insert + print("Pushing documents to database...") + Document.objects.from_csv( + f, + static_mapping={ + "project_id": str(self.id), + "created_at": timezone.now(), + "updated_at": timezone.now(), + }, + ignore_conflicts=True, + **kwargs, ) - annos = self.load_annos() - - missing_annos = annos[~annos["text_hash"].isin(trainset_hashes)] - missing_annos = missing_annos.drop_duplicates(subset="text_hash") - if len(missing_annos): - updates = [ - LabelTrainsetExample( - label_id=self.id, - text_hash=row["text_hash"], - text=row["text"], - split="train", - ) - for row in missing_annos.to_dict("records") - ] - LabelTrainsetExample.objects.bulk_create(updates, batch_size=1000) - cols = ["text_hash", "text", "split", "pred", "reason"] - data = pd.DataFrame( - self.trainset_examples.all().values(*cols), - columns=cols, - ) + @property + def project_dir(self): + """Return the project directory path.""" + return Path(CLX_HOME) / "projects" / str(self.id) - flagged_hashes = annos[annos["value"].isna()]["text_hash"].tolist() - annos = annos[~annos["value"].isna()] - annos = annos[["text_hash", "value"]].rename( - columns={"value": "anno_value"} - ) - data = data.merge(annos, on="text_hash", how="left") - - data["value"] = data["anno_value"].fillna(data["pred"]) - data.loc[data["text_hash"].isin(flagged_hashes), "value"] = None - - data = data.sample(frac=1, random_state=42) - data = data.reset_index(drop=True) - - minimal_fn = self.get_minimal_fn() - likely_fn = self.get_likely_fn() - data["bucket"] = data["text"].apply( - lambda x: "excluded" - if not minimal_fn(x) - else "likely" - if likely_fn(x) - else "neutral" - ) - return data - - def update_trainset_preds(self, num_threads=32): - data = self.load_trainset() - data = data[data["pred"].isna()] - texts = data["text"].tolist() - preds = self.batch_predict(texts, num_threads=num_threads) - data["pred"] = [x.get("value") for x in preds] - data["reason"] = [x.get("reason") for x in preds] - examples = self.trainset_examples.all() - examples = {e.text_hash: e for e in examples} - updates = [] - for row in data.to_dict("records"): - if row["text_hash"] in examples: - example = examples[row["text_hash"]] - example.pred = row["pred"] - example.reason = row["reason"] - updates.append(example) - LabelTrainsetExample.objects.bulk_update( - updates, - fields=["pred", "reason"], - batch_size=1000, - ) - self.update_trainset_pred_counts() - self.sync_trainset_pred_tags() - self.trainset_predictions_updated_at = timezone.now() - self.save() - - def update_trainset_pred_counts(self): - data = self.load_trainset() - if len(data) and "pred" in data.columns: - data = data.dropna(subset=["pred"]) - preds = data["pred"].astype(bool) - self.trainset_num_positive_preds = preds.sum() - self.trainset_num_negative_preds = (~preds).sum() - else: - self.trainset_num_positive_preds = 0 - self.trainset_num_negative_preds = 0 - self.save() - - def load_predictor(self): - args = { - "label_name": self.name, - "project_instructions": self.project.instructions, - "label_instructions": self.instructions, - "decisions": self.decisions.values("text", "value", "reason"), - } + def export_data(self, batch_size=100_000): + """Export documents to docs.csv in the project directory. 
- def predict_fn(text: str): - for _ in range(3): - try: - agent = AnnoAgent(**args) - anno = agent(text) - return { - "status": "success", - "value": anno.value, - "reason": anno.reason, - } - except Exception as e: - print(f"Error predicting {text}: {e}") - time.sleep(5) - return {"status": "error"} - - return predict_fn - - def batch_predict(self, texts: list[str], num_threads: int = 32): - predictor = self.load_predictor() - with ThreadPoolExecutor(max_workers=num_threads) as executor: - futures = [executor.submit(predictor, text) for text in texts] - for _ in tqdm( - as_completed(futures), total=len(futures), desc="Predicting" - ): + Only exports documents created since the last export. + """ + self.project_dir.mkdir(parents=True, exist_ok=True) + docs_path = self.project_dir / "docs.csv" + exported_path = self.project_dir / "exported.txt" + + # Check last export time. + last_export = None + if exported_path.exists(): + try: + last_export = datetime.fromisoformat( + exported_path.read_text().strip() + ) + except (ValueError, OSError): pass - return [future.result() for future in futures] - @property - def trainset_train_tag(self): - tag, _ = LabelTag.objects.get_or_create( - name="trainset:train", - label=self, - ) - return tag + qs = self.documents.order_by("created_at") + if last_export: + qs = qs.filter(created_at__gt=last_export) - @property - def trainset_eval_tag(self): - tag, _ = LabelTag.objects.get_or_create( - name="trainset:eval", - label=self, - ) - return tag + total = qs.count() + if total == 0: + return - @property - def trainset_pred_tag(self): - tag, _ = LabelTag.objects.get_or_create( - name="trainset:pred", - label=self, - ) - return tag + try: + for offset in tqdm( + range(0, total, batch_size), + desc="Exporting docs", + total=(total + batch_size - 1) // batch_size, + ): + batch = qs.values_list("id", "text")[ + offset : offset + batch_size + ] + data = pd.DataFrame( + list(batch), columns=["document_id", "text"] + ) + pd_save_or_append(data, docs_path) + except BaseException: + docs_path.unlink(missing_ok=True) + exported_path.unlink(missing_ok=True) + raise + + now = datetime.now(UTC).isoformat() + exported_path.write_text(now) + + def update_tasks(self): + """Sync tasks based on current project/label state. 
+ + Rules: + - No project instructions → project_understanding task (no label) + - Has project instructions but label lacks instructions → label_understanding per label + - Both have instructions but label has no training examples → sampling_strategy per label + - Label has unannotated training examples → annotate per label + """ + from django.db.models import Count, Q - def get_trainset_finetune_tag(self, config_name): - tag, _ = LabelTag.objects.get_or_create( - name=f"trainset:ft:{config_name}", - label=self, - ) - return tag + expected = [] # list of (prompt_id, label_id | None) - @property - def finetune_tag(self): - tag, _ = LabelTag.objects.get_or_create( - name="ft", - label=self, - ) - return tag + if not self.instructions.strip(): + expected.append(("project_understanding", None)) + else: + labels = list(self.labels.all()) + for label in labels: + if not label.instructions.strip(): + expected.append(("label_understanding", label.id)) + else: + ld_stats = LabelDocument.objects.filter( + label=label + ).aggregate( + total=Count("id"), + annotated=Count( + "id", + filter=Q(annotations__source="agent"), + ), + ) + if ld_stats["total"] == 0: + expected.append(("sampling_strategy", label.id)) + elif ld_stats["annotated"] < ld_stats["total"]: + expected.append(("annotate", label.id)) + + expected_set = set(expected) + existing = {(t.prompt_id, t.label_id): t for t in self.tasks.all()} + + # Delete tasks no longer expected (keep in-progress and awaiting-input) + keep_statuses = (Task.Status.IN_PROGRESS, Task.Status.AWAITING_INPUT) + to_delete = [ + t.id + for key, t in existing.items() + if key not in expected_set and t.status not in keep_statuses + ] + if to_delete: + Task.objects.filter(id__in=to_delete).delete() + + # Create missing tasks + to_create = [ + Task(project=self, prompt_id=pid, label_id=lid) + for pid, lid in expected + if (pid, lid) not in existing + ] + if to_create: + Task.objects.bulk_create(to_create, ignore_conflicts=True) - @property - def anno_true_tag(self): - tag, _ = LabelTag.objects.get_or_create( - name="anno:true", - label=self, - ) - return tag + return list(self.tasks.select_related("label").order_by("created_at")) - @property - def anno_false_tag(self): - tag, _ = LabelTag.objects.get_or_create( - name="anno:false", - label=self, - ) - return tag - @property - def anno_flag_tag(self): - tag, _ = LabelTag.objects.get_or_create( - name="anno:flag", - label=self, - ) - return tag +class Document(Base): + """Model for documents within a project.""" - def sync_trainset_tags(self): - """Sync tags for train/eval splits to match current trainset examples.""" - model = self.project.get_search_model() + project = models.ForeignKey( + Project, on_delete=models.CASCADE, related_name="documents" + ) + text = models.TextField() + text_prefix = models.CharField(max_length=50) + meta = models.JSONField(default=dict, null=True, blank=True) + shuffle_key = models.IntegerField() + text_hash = models.CharField(max_length=64) - train_hashes = list( - self.trainset_examples.filter(split="train").values_list( - "text_hash", flat=True - ) - ) - if train_hashes: - train_ids = list( - model.objects.filter(text_hash__in=train_hashes).values_list( - "id", flat=True - ) + class Meta: + constraints = [ + models.UniqueConstraint( + fields=["project", "text_hash"], + name="document_project_text_hash_uniq", ) - else: - train_ids = [] - model.bulk_replace_tag(self.trainset_train_tag, train_ids) + ] + indexes = [ + models.Index( + fields=["shuffle_key", "id"], + name="shuffle_key_idx", + ), 
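+            # text_prefix carries a text_pattern_ops index for fast prefix
+            # matches; the trigram GIN index below backs substring search.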
+ models.Index( + fields=["text_prefix"], + name="text_prefix_idx", + opclasses=["text_pattern_ops"], + ), + GinIndex( + fields=["text"], + name="text_trgm_idx", + opclasses=["gin_trgm_ops"], + ), + ] - eval_hashes = list( - self.trainset_examples.filter(split="eval").values_list( - "text_hash", flat=True - ) - ) - if eval_hashes: - eval_ids = list( - model.objects.filter(text_hash__in=eval_hashes).values_list( - "id", flat=True - ) - ) - else: - eval_ids = [] - model.bulk_replace_tag(self.trainset_eval_tag, eval_ids) - - def sync_trainset_pred_tags(self): - """Sync tag for positive predictions to match current predicted positives.""" - model = self.project.get_search_model() - pos_hashes = list( - self.trainset_examples.filter(pred=True).values_list( - "text_hash", flat=True - ) - ) - if pos_hashes: - pos_ids = list( - model.objects.filter(text_hash__in=pos_hashes).values_list( - "id", flat=True - ) - ) - else: - pos_ids = [] - model.bulk_replace_tag(self.trainset_pred_tag, pos_ids) - - def get_finetune_run_name(self, config_name): - return f"{self.project_id}__{label2slug(self.name)}__{config_name}" - - def get_finetune_run_pipe(self, config_name): - run_name = self.get_finetune_run_name(config_name) - model_path = f"/runpod-volume/clx/runs/{run_name}/model" - return pipeline(task="classification", model=model_path, remote=True) - - def prepare_finetune( - self, config_name, batch_size=16, gradient_accumulation_steps=1 - ): - model = self.project.get_search_model() - config = model.finetune_configs[config_name] - data = self.load_trainset() - data = data.sample(frac=1, random_state=42) - data = ( - data[["text_hash", "text", "value", "split"]] - .rename(columns={"value": "label"}) - .dropna() - ) - data["label"] = data["label"].apply(lambda x: "yes" if x else "no") - train_data = data[data["split"] == "train"] - eval_data = data[data["split"] == "eval"] - - num_train_epochs = config["training_args"].get("num_train_epochs", 1) - config["training_args"]["num_train_epochs"] = num_train_epochs - total_steps = (num_train_epochs * len(train_data)) // ( - batch_size * gradient_accumulation_steps - ) - save_steps = total_steps // 9 - config["training_args"]["eval_strategy"] = "steps" - config["training_args"]["save_strategy"] = "steps" - config["training_args"]["eval_steps"] = save_steps - config["training_args"]["save_steps"] = save_steps - config["training_args"]["per_device_train_batch_size"] = batch_size - config["training_args"]["per_device_eval_batch_size"] = batch_size - config["training_args"]["gradient_accumulation_steps"] = ( - gradient_accumulation_steps - ) - run_config = { - "task": "classification", - "run_name": self.get_finetune_run_name(config_name), - "label_names": ["yes", "no"], - **config, - } +class Label(Base): + """Model for labels within a project.""" - return train_data, eval_data, run_config + project = models.ForeignKey( + Project, on_delete=models.CASCADE, related_name="labels" + ) + name = models.CharField(max_length=255) + instructions = models.TextField(blank=True, default="") + autopilot_thread = models.ForeignKey( + "Thread", + on_delete=models.SET_NULL, + null=True, + blank=True, + related_name="+", + ) + finetune_id = models.CharField(max_length=255, blank=True, default="") + finetune_training_args = models.JSONField(default=dict, blank=True) + finetuned_at = models.DateTimeField(null=True, blank=True) + finetune_status = models.CharField(max_length=20, blank=True, default="") + predicted_at = models.DateTimeField(null=True, blank=True) + prediction_stats = 
models.JSONField(default=dict, blank=True) - def train_finetune(self, config_name): - """Train a finetune model for this label.""" - train_data, eval_data, run_config = self.prepare_finetune(config_name) + class Meta: + constraints = [ + models.UniqueConstraint( + fields=["project", "name"], + name="label_project_name_uniq", + ) + ] - run = training_run(**run_config) - outputs = run.train(train_data, eval_data, overwrite=True, remote=True) + def finetune(self, training_args=None): + """Kick off a remote finetuning job for this label.""" + import os + import random - data = pd.concat([train_data, eval_data]) + import pandas as pd + import requests - pipe = self.get_finetune_run_pipe(config_name) - data["pred"] = pipe(data["text"].tolist(), batch_size=16) - data = data[data["pred"] == "yes"] + from clx.utils import S3 - tag = self.get_trainset_finetune_tag(config_name) - model = self.project.get_search_model() - example_ids = model.objects.filter( - text_hash__in=data["text_hash"].tolist() + training_args = training_args or {} + self.finetune_training_args = training_args + self.finetune_status = "pending" + self.save( + update_fields=[ + "finetune_training_args", + "finetune_status", + "updated_at", + ] ) - example_ids = example_ids.values_list("id", flat=True) - model.bulk_replace_tag(tag.id, example_ids) - finetune, _ = LabelFinetune.objects.get_or_create( - label=self, config_name=config_name + # Assemble data from annotated label documents (yes/no only). + # Single query: join through to document text and annotation value. + rows = list( + LabelDocument.objects.filter( + label=self, + annotations__source="agent", + annotations__value__in=["yes", "no"], + ).values_list( + "document__text", + "annotations__value", + ) ) - finetune.eval_results = outputs["results"] - finetune.finetuned_at = timezone.now() - finetune.save() + rows = [{"text": text, "label": value} for text, value in rows] + + random.shuffle(rows) + df = pd.DataFrame(rows) + split = max(1, int(len(df) * 0.2)) + eval_data = df.iloc[:split] + train_data = df.iloc[split:] + + # Build training args with sensible defaults. + import math + + from clx.ml import training_run + + batch_size = training_args.get("per_device_train_batch_size", 8) + grad_accum = training_args.get("gradient_accumulation_steps", 1) + effective_batch = batch_size * grad_accum + steps_per_epoch = max(1, math.ceil(len(train_data) / effective_batch)) + checkpoint_steps = max(1, steps_per_epoch // 5) + + defaults = { + "per_device_train_batch_size": batch_size, + "per_device_eval_batch_size": batch_size, + "gradient_accumulation_steps": grad_accum, + "num_train_epochs": 3, + "eval_strategy": "steps", + "eval_steps": checkpoint_steps, + "save_steps": checkpoint_steps, + } + merged_args = {**defaults, **training_args} - return finetune + run = training_run( + task="classification", + run_name=str(self.id), + label_names=["yes", "no"], + training_args=merged_args, + ) - def predict_finetune(self, batch_size=16, num_workers=64, force=False): - """Run finetune predictions across the entire corpus.""" - cache_path = self.data_dir / "finetune_predictions_cache.csv" - self.data_dir.mkdir(parents=True, exist_ok=True) - config_name = self.project.get_search_model().main_finetune_config - if config_name is None: - raise ValueError("Set main_finetune_config for this project") + # Upload data to S3 and submit to RunPod. 
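+        # Request body sketch for the RunPod call below (values are
+        # illustrative; the real config comes from run.config at runtime):
+        #   {"input": {"training_run": {...}, "s3_bucket": "<bucket>",
+        #              "s3_prefix": "runpod/finetune/<uuid>", "overwrite": True}}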
+ import tempfile + import uuid as _uuid + from pathlib import Path - if force and cache_path.exists(): - cache_path.unlink() + endpoint_id = os.getenv("RUNPOD_FINETUNE_ENDPOINT_ID") + api_key = os.getenv("RUNPOD_API_KEY") + if not endpoint_id or not api_key: + self.finetune_status = "error" + self.save(update_fields=["finetune_status", "updated_at"]) + raise ValueError( + "RUNPOD_FINETUNE_ENDPOINT_ID and RUNPOD_API_KEY must be set" + ) - cached_ids = set() - if cache_path.exists(): - cached_data = pd.read_csv(cache_path) - cached_ids = set(cached_data["id"].unique().tolist()) + s3 = S3() + job_key = str(_uuid.uuid4()) + s3_prefix = f"runpod/finetune/{job_key}" + + with tempfile.TemporaryDirectory() as tmpdir: + train_path = Path(tmpdir) / "train.csv" + train_data.to_csv(train_path, index=False) + s3.upload(train_path, f"{s3_prefix}/train.csv") + eval_path = Path(tmpdir) / "eval.csv" + eval_data.to_csv(eval_path, index=False) + s3.upload(eval_path, f"{s3_prefix}/eval.csv") + + config = run.config + del config["run_dir_parent"] + payload = { + "input": { + "training_run": config, + "s3_bucket": s3.bucket, + "s3_prefix": s3_prefix, + "overwrite": True, + } + } - model = self.project.get_search_model() - pipe = self.get_finetune_run_pipe(config_name) + response = requests.post( + f"https://api.runpod.ai/v2/{endpoint_id}/run", + headers={"Authorization": f"Bearer {api_key}"}, + json=payload, + ) + response.raise_for_status() + job_id = response.json()["id"] + + self.finetune_id = job_id + self.finetune_status = "in_progress" + self.save( + update_fields=[ + "finetune_id", + "finetune_status", + "updated_at", + ] + ) + return job_id + + @property + def pipe(self): + """Remote classification pipeline for the finetuned model.""" + from clx.ml import pipeline - minimal_heuristics = LabelHeuristic.objects.filter( - is_minimal=True, label=self + model_path = f"/runpod-volume/clx/runs/{self.id}/model" + return pipeline( + task="classification", + model=model_path, + remote=True, ) - minimal_conditions = [h.get_apply_fn() for h in minimal_heuristics] - - def minimal_condition_fn(text): - return any(condition(text) for condition in minimal_conditions) - - total_examples = model.objects.count() - outer_batch_size = 1024 * 500 - for batch in tqdm( - model.objects.batch_df("id", "text", batch_size=outer_batch_size), - desc=f"Predicting {config_name}", - total=total_examples // outer_batch_size, - ): - batch = batch[~batch["id"].isin(cached_ids)] - batch = batch[batch["text"].apply(minimal_condition_fn)] - if len(batch) > 0: - batch["value"] = pipe( - batch["text"].tolist(), - batch_size=batch_size, - num_workers=num_workers, - max_length=768, - truncation=True, - ) - batch["value"] = batch["value"].apply(lambda x: x == "yes") - pd_save_or_append(batch[["id", "value"]], cache_path) - - if cache_path.exists(): - all_preds = pd.read_csv(cache_path) - positive_ids = all_preds[all_preds["value"]]["id"].tolist() - tag = self.finetune_tag - model.bulk_replace_tag(tag.id, positive_ids) - finetune = self.fintunes.filter(config_name=config_name).first() - if finetune: - finetune.predicted_at = timezone.now() - finetune.save() - cache_path.unlink() - - print( - f"Predictions complete: {len(positive_ids):,} positive out of {len(all_preds):,} total" - ) - def update_all(self, num_threads=128, predict=False, force=False): - """Update all components that are out of date based on timestamps. 
+ def predict(self): + """Run predictions across all label documents using the finetuned model.""" + from django.utils import timezone - Runs the full pipeline in order, but only steps that need updating: - 1. Resample trainset (if decisions newer than trainset) - 2. Run predictions (if trainset newer than predictions) - 3. Train finetunes (if predictions newer than finetunes) - 4. Run global corpus predictions (if predict is True and finetune newer than global predictions) - """ - missing = [] - if not self.heuristics.filter(is_minimal=True).exists(): - missing.append("at least one minimal heuristic") - if not self.heuristics.filter(is_likely=True).exists(): - missing.append("at least one likely heuristic") - if not self.decisions.filter(value=True).exists(): - missing.append("at least one positive decision") - if not self.decisions.filter(value=False).exists(): - missing.append("at least one negative decision") - - if missing: - print("Cannot run update_all - missing required setup:") - for item in missing: - print(f" - {item}") - return + if self.finetune_status != "completed": + raise ValueError("No completed finetune available.") - model = self.project.get_search_model() - finetune_configs = list(model.finetune_configs.keys()) + ld_qs = LabelDocument.objects.filter(label=self).select_related( + "document" + ) + ld_list = list(ld_qs.values_list("id", "document__text")) + if not ld_list: + return - # Get latest decision timestamp - latest_decision = self.decisions.order_by("-updated_at").first() - latest_decision_at = ( - latest_decision.updated_at if latest_decision else None + ld_ids = [row[0] for row in ld_list] + texts = [row[1] for row in ld_list] + + results = self.pipe.predict(texts, batch_size=16, return_scores=True) + + # Build predictions map. + predictions = {} + for ld_id, scores in zip(ld_ids, results): + yes_score = scores.get("yes", 0) + no_score = scores.get("no", 0) + pred = "yes" if yes_score >= no_score else "no" + top_score = max(yes_score, no_score) + confidence = abs(top_score - 0.5) * 2 + predictions[ld_id] = (pred, confidence) + + # Bulk update predictions. + ld_objs = [] + for ld_id, (pred, confidence) in predictions.items(): + obj = LabelDocument(id=ld_id) + obj.prediction = pred + obj.prediction_confidence = confidence + ld_objs.append(obj) + + LabelDocument.objects.bulk_update( + ld_objs, + ["prediction", "prediction_confidence"], + batch_size=1000, ) - # Step 1: Resample trainset if decisions are newer - if force or ( - latest_decision_at - and ( - not self.trainset_updated_at - or latest_decision_at > self.trainset_updated_at + # Compute F1 and accuracy on annotated examples (yes/no only). 
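+        # Definitions used below (computed only over documents that have an
+        # agent 'yes'/'no' annotation):
+        #   precision = tp / (tp + fp), recall = tp / (tp + fn),
+        #   f1 = 2 * precision * recall / (precision + recall),
+        #   accuracy = correct / total.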
+ annotated = dict( + ClassificationAnnotation.objects.filter( + label_document_id__in=ld_ids, + source="agent", + value__in=["yes", "no"], + ).values_list("label_document_id", "value") + ) + + if annotated: + tp = fp = fn = correct = 0 + total = len(annotated) + for ld_id, true_val in annotated.items(): + pred_val = predictions[ld_id][0] + if pred_val == true_val: + correct += 1 + if pred_val == "yes" and true_val == "yes": + tp += 1 + elif pred_val == "yes" and true_val == "no": + fp += 1 + elif pred_val == "no" and true_val == "yes": + fn += 1 + + precision = tp / (tp + fp) if (tp + fp) > 0 else 0 + recall = tp / (tp + fn) if (tp + fn) > 0 else 0 + f1 = ( + 2 * precision * recall / (precision + recall) + if (precision + recall) > 0 + else 0 ) - ): - print("Step 1: Resampling trainset") - self.update_trainset() - self.refresh_from_db() - - # Step 2: Run predictions if trainset is newer - if force or ( - self.trainset_updated_at - and ( - not self.trainset_predictions_updated_at - or self.trainset_updated_at - > self.trainset_predictions_updated_at + self.prediction_stats = { + "f1": round(f1, 4), + "accuracy": round(correct / total, 4), + "precision": round(precision, 4), + "recall": round(recall, 4), + "total": total, + } + else: + self.prediction_stats = {} + + self.predicted_at = timezone.now() + self.save( + update_fields=[ + "predicted_at", + "prediction_stats", + "updated_at", + ] + ) + + def recalculate_prediction_stats(self): + """Recompute F1/accuracy from existing predictions vs annotations.""" + ld_data = list( + LabelDocument.objects.filter( + label=self, ) - ): - print("Step 2: Running predictions") - self.update_trainset_preds(num_threads=num_threads) - self.refresh_from_db() - - # Step 3: Train finetunes if predictions are newer - for config_name in finetune_configs: - finetune = self.fintunes.filter(config_name=config_name).first() - finetuned_at = finetune.finetuned_at if finetune else None - - if force or ( - self.trainset_predictions_updated_at - and ( - not finetuned_at - or self.trainset_predictions_updated_at > finetuned_at - ) - ): - print(f"Step 3: Training finetune: {config_name}") - self.train_finetune(config_name) - - # Step 4: Run global corpus predictions if finetune is newer - if predict: - ft = self.fintunes.filter( - config_name=self.project.get_search_model().main_finetune_config - ).first() - if ft and ( - force - or ( - ft.finetuned_at - and ( - not ft.predicted_at - or ft.finetuned_at > ft.predicted_at - ) - ) - ): - print("Step 4: Running global predictions") - self.predict_finetune(force=force) + .exclude(prediction="") + .filter(prediction__isnull=False) + .values_list("id", "prediction") + ) + if not ld_data: + self.prediction_stats = {} + self.save(update_fields=["prediction_stats", "updated_at"]) + return - print("Update complete!") + predictions = dict(ld_data) + ld_ids = list(predictions.keys()) + + annotated = dict( + ClassificationAnnotation.objects.filter( + label_document_id__in=ld_ids, + source="agent", + value__in=["yes", "no"], + ).values_list("label_document_id", "value") + ) + + if annotated: + tp = fp = fn = correct = 0 + total = len(annotated) + for ld_id, true_val in annotated.items(): + pred_val = predictions.get(ld_id) + if not pred_val: + continue + if pred_val == true_val: + correct += 1 + if pred_val == "yes" and true_val == "yes": + tp += 1 + elif pred_val == "yes" and true_val == "no": + fp += 1 + elif pred_val == "no" and true_val == "yes": + fn += 1 + + precision = tp / (tp + fp) if (tp + fp) > 0 else 0 + recall = tp / (tp 
+ fn) if (tp + fn) > 0 else 0 + f1 = ( + 2 * precision * recall / (precision + recall) + if (precision + recall) > 0 + else 0 + ) + self.prediction_stats = { + "f1": round(f1, 4), + "accuracy": round(correct / total, 4), + "precision": round(precision, 4), + "recall": round(recall, 4), + "total": total, + } + else: + self.prediction_stats = {} - class Meta: - unique_together = ("project", "name") + self.save(update_fields=["prediction_stats", "updated_at"]) -class LabelTag(BaseModel): - """Model for label tags.""" +class Prompt(Base): + """A customizable prompt template for a project.""" - name = models.CharField(max_length=255) - label = models.ForeignKey( - Label, on_delete=models.CASCADE, related_name="tags" - ) - slug = models.CharField(max_length=255) - heuristic = models.OneToOneField( - "LabelHeuristic", - on_delete=models.CASCADE, - null=True, - blank=True, - related_name="tag", + project = models.ForeignKey( + Project, on_delete=models.CASCADE, related_name="prompts" ) - - def save(self, *args, **kwargs): - self.slug = label2slug(self.name) + ":" + label2slug(self.label.name) - super().save(*args, **kwargs) + prompt_id = models.CharField(max_length=255) + name = models.CharField(max_length=255) + content = models.TextField(blank=True, default="") + built_in = models.BooleanField(default=False) class Meta: - unique_together = ("name", "label") + constraints = [ + models.UniqueConstraint( + fields=["project", "prompt_id"], + name="prompt_project_promptid_uniq", + ) + ] -class LabelDecision(BaseModel): - """Model for label decision boundaries.""" +class Thread(Base): + """Model for LLM threads tied to a label.""" label = models.ForeignKey( - Label, on_delete=models.CASCADE, related_name="decisions" + Label, on_delete=models.CASCADE, related_name="threads" ) - text_hash = models.CharField(max_length=255) - text = models.TextField(null=True, blank=True) - value = models.BooleanField() - reason = models.TextField() - added_to_sample = models.BooleanField(default=False) - - def save(self, *args, added_to_sample=False, **kwargs): - self.added_to_sample = added_to_sample - super().save(*args, **kwargs) - - class Meta: - unique_together = ("label", "text_hash") + model = models.CharField( + max_length=255, default=django_settings.DEFAULT_MODEL + ) + state = models.JSONField(default=dict, blank=True) + total_cost = models.FloatField(default=0.0) + autopilot_locked = models.BooleanField(default=False) -class LabelQuerystring(BaseModel): - """Model for label querystrings.""" +class LabelDocument(Base): + """Links a document to a label (e.g. 
as a training example).""" label = models.ForeignKey( - Label, on_delete=models.CASCADE, related_name="querystrings" + Label, on_delete=models.CASCADE, related_name="label_documents" ) - querystring = models.TextField() - num_examples = models.IntegerField(default=50) - added_to_sample = models.BooleanField(default=False) - - def save(self, *args, added_to_sample=False, **kwargs): - self.added_to_sample = added_to_sample - super().save(*args, **kwargs) + document = models.ForeignKey( + "Document", on_delete=models.CASCADE, related_name="label_documents" + ) + prediction = models.CharField( + max_length=3, + blank=True, + default="", + choices=[("yes", "yes"), ("no", "no")], + ) + prediction_confidence = models.FloatField(null=True, blank=True) class Meta: - unique_together = ("label", "querystring") + constraints = [ + models.UniqueConstraint( + fields=["label", "document"], + name="labeldocument_label_document_uniq", + ) + ] -class LabelHeuristic(BaseModel): - """Model for label heuristics.""" +class ClassificationAnnotation(Base): + """An annotation on a label document.""" - label = models.ForeignKey( - Label, on_delete=models.CASCADE, related_name="heuristics" + class Value(models.TextChoices): + YES = "yes" + NO = "no" + SKIP = "skip" + + label_document = models.ForeignKey( + LabelDocument, on_delete=models.CASCADE, related_name="annotations" ) - querystring = models.TextField(null=True, blank=True) - custom = models.CharField(max_length=255, null=True, blank=True) - applied_at = models.DateTimeField(null=True, blank=True) - is_minimal = models.BooleanField(default=False) - is_likely = models.BooleanField(default=False) - num_examples = models.IntegerField(default=0) + value = models.CharField(max_length=4, choices=Value.choices) + source = models.CharField(max_length=255) - def save(self, *args, **kwargs): - if sum([bool(self.querystring), bool(self.custom)]) != 1: - raise ValueError( - "Exactly one of querystring or custom must be provided." 
+ class Meta: + constraints = [ + models.UniqueConstraint( + fields=["label_document", "source"], + name="annotation_labeldoc_source_uniq", ) - super().save(*args, **kwargs) - if self.applied_at is not None: - self.label.update_counts() - - def delete(self, *args, **kwargs): - self.is_minimal = False - self.is_likely = False - self.save() - self.label.update_counts() - super().delete(*args, **kwargs) - - @property - def name(self): - if self.querystring is not None: - return f"h:qs:{self.querystring}" - elif self.custom is not None: - return f"h:fn:{self.custom}" - - @classmethod - def sync_custom_heuristics(cls): - for heuristic in cls.objects.filter(custom__isnull=False): - label = heuristic.label - if ( - heuristic.custom not in custom_heuristics - or label.name - != custom_heuristics[heuristic.custom]["label_name"] - or label.project_id - != custom_heuristics[heuristic.custom]["project_id"] - ): - heuristic.delete() - - for custom_name, custom_heuristic in custom_heuristics.items(): - heuristic_exists = cls.objects.filter( - label__name=custom_heuristic["label_name"], - label__project_id=custom_heuristic["project_id"], - custom=custom_name, - ).exists() - if not heuristic_exists: - label, _ = Label.objects.get_or_create( - name=custom_heuristic["label_name"], - project_id=custom_heuristic["project_id"], - ) - heuristic = cls.objects.create( - label=label, - custom=custom_name, - ) + ] - def get_apply_fn(self, **kwargs): - def apply_fn(text): - if self.querystring is not None: - text = text.lower() - querystring = self.querystring.lower() - - for and_part in querystring.split(","): - and_part = and_part.strip() - meets_any_or = False - for or_part in and_part.split("|"): - or_part = or_part.strip() - negated = False - if or_part.startswith("~"): - or_part = or_part[1:].strip() - negated = True - if or_part.startswith("^"): - or_part = or_part[1:].strip() - if text.startswith(or_part.strip()) != negated: - meets_any_or = True - elif (or_part.strip() in text) != negated: - meets_any_or = True - if not meets_any_or: - return False - return True - elif self.custom is not None: - return custom_heuristics[self.custom]["apply_fn"]( - text, **kwargs - ) - return apply_fn +class Task(Base): + """A pending task for a project (e.g. 
'annotate label X').""" - def apply(self): - tag, _ = LabelTag.objects.get_or_create( - name=self.name, label=self.label, heuristic=self - ) - apply_fn = self.get_apply_fn() - example_ids = [] - model = self.label.project.get_search_model() - batch_size = 1000000 - batches = model.objects.batch_df("id", "text", batch_size=batch_size) - for batch in tqdm( - batches, - desc="Applying heuristic", - total=model.objects.count() // batch_size, - ): - batch = batch[batch["text"].apply(apply_fn)] - example_ids.extend(batch["id"].tolist()) - model.bulk_replace_tag(tag.id, example_ids) - self.applied_at = timezone.now() - self.num_examples = model.objects.tags(any=[tag.id]).count() - self.save() - self.label.update_counts() - - -class LabelTrainsetExample(BaseModel): - """Model for label trainset examples.""" + class Status(models.TextChoices): + PENDING = "pending" + IN_PROGRESS = "in_progress" + AWAITING_INPUT = "awaiting_input" + project = models.ForeignKey( + Project, on_delete=models.CASCADE, related_name="tasks" + ) + prompt_id = models.CharField(max_length=255) label = models.ForeignKey( - Label, on_delete=models.CASCADE, related_name="trainset_examples" + Label, + on_delete=models.CASCADE, + null=True, + blank=True, + related_name="tasks", ) - text_hash = models.CharField(max_length=255) - text = models.TextField(null=True, blank=True) - split = models.CharField( - max_length=10, - choices=[("train", "Train"), ("eval", "Eval")], + status = models.CharField( + max_length=20, choices=Status.choices, default=Status.PENDING ) - pred = models.BooleanField(null=True, blank=True) - reason = models.TextField(null=True, blank=True) class Meta: - unique_together = ("label", "text_hash") - - -class LabelFinetune(BaseModel): - """Model for single-label finetuned models.""" - - label = models.ForeignKey( - Label, on_delete=models.CASCADE, related_name="fintunes" - ) - config_name = models.CharField(max_length=255) - eval_results = models.JSONField(null=True, blank=True) - finetuned_at = models.DateTimeField(null=True, blank=True) - predicted_at = models.DateTimeField(null=True, blank=True) - - -class DocketEntry(SearchDocumentModel): - """Docket entry model for main document entries.""" - - project_id = "docket-entry" - finetune_configs = { - "main": { - "base_model_name": "answerdotai/ModernBERT-base", - "training_args": { - "num_train_epochs": 10, - "learning_rate": 5e-5, - "warmup_ratio": 0.05, - "bf16": True, - }, - }, - } - main_finetune_config = "main" - - id = models.BigIntegerField(primary_key=True) - recap_id = models.BigIntegerField(unique=True) - docket_id = models.BigIntegerField() - entry_number = models.BigIntegerField(null=True, blank=True) - date_filed = models.DateField(null=True, blank=True) - - -DocketEntry.create_tags_model() - + constraints = [ + models.UniqueConstraint( + fields=["project", "prompt_id", "label"], + name="task_project_prompt_label_uniq", + ) + ] -class DocketEntryShort(SearchDocumentModel): - """Model for attachments and docket entry short descriptions.""" - project_id = "docket-entry-short" +class Message(Base): + """Model for messages within a thread.""" - text = models.TextField(unique=True) - text_type = models.CharField( - max_length=255, - choices=[ - ("short_description", "Short Description"), - ("attachment", "Attachment"), - ], + thread = models.ForeignKey( + Thread, on_delete=models.CASCADE, related_name="messages" ) - count = models.IntegerField(default=0) - - -DocketEntryShort.create_tags_model() + data = models.JSONField(default=dict) + num_tokens = 
 models.IntegerField(default=0)
+    is_compact = models.BooleanField(default=False)
+    hidden = models.BooleanField(default=False)
diff --git a/clx/app/prompts.py b/clx/app/prompts.py
new file mode 100644
index 0000000..86080db
--- /dev/null
+++ b/clx/app/prompts.py
@@ -0,0 +1,152 @@
+PROJECT_UNDERSTANDING = """
+# Step 1: Understand the data
+
+Search the project data to understand what this project is about. Start with \
+a broad sample (empty query, ~20 results), but if documents are short, pull \
+more to get a representative picture.
+
+# Step 2: Clarify your understanding of the project
+
+Using the project name and data, form your best hypothesis about the project's \
+goals and domain. Then ask the user a series of targeted questions to validate \
+and deepen your understanding. Keep asking follow-ups until you have a thorough \
+grasp of the project's aims, scope, and any nuances.
+
+# Step 3: Update the project instructions
+
+Once you have a solid understanding, update the project instructions with a \
+detailed overview covering the project's purpose, domain, data characteristics, \
+and any important context.
+
+> Note: the active label in your system prompt may be one of many. Focus your \
+questions on the project as a whole, not the currently active label.
+"""
+
+LABEL_UNDERSTANDING = """
+# Step 1: Understand the data
+
+Search the project data to understand what this label should capture. Use the \
+label name, project instructions, and the data itself to form a hypothesis \
+about what documents should be included and excluded.
+
+Be thorough in your search — look for clear positives, clear negatives, and \
+edge cases. Try different query patterns to surface tricky examples that might \
+be ambiguous.
+
+# Step 2: Clarify your understanding of the label
+
+Then ask the user targeted questions to validate your understanding of the \
+label's criteria. Keep asking follow-ups until you can confidently distinguish \
+what belongs under this label and what doesn't. You can ask multiple questions \
+during each turn. Be very thorough.
+
+# Step 3: Update the label instructions
+
+Once ready, update the label instructions with detailed, specific annotation \
+guidelines. Include clear inclusion/exclusion criteria and address any edge \
+cases you identified.
+"""
+
+
+SAMPLING_STRATEGY = """
+# Step 1: Get a sense of training set size
+
+The expected size for the initial training set should be mentioned in the \
+project instructions. If it isn't, ask the user and update the project instructions \
+to reflect their answer.
+
+# Step 2: Come up with minimal and likely heuristics
+
+Search through the data and try to come up with two types of queries:
+
+- A minimal heuristic: This should be broad enough that no positive example would \
+ever plausibly be excluded by the query. If the language of positive examples is \
+such that it would not be possible to scope examples with keyword conditions, \
+then you might leave this blank.
+
+- A likely heuristic: This should be narrow enough that it catches some obvious \
+positive examples. It does not need to be perfect or complete, but it should catch \
+many easy positives.
+
+# Step 3: Store the heuristics in the label instructions
+
+Add a note detailing the function of the heuristics and their queries to the label \
+instructions.
+
+# Step 4: Create the initial sample
+
+Sample approximately 1/3 of the expected training set size from three buckets:
+
+- Things that do not satisfy the minimal heuristic.
+- Things that satisfy the minimal heuristic but not the likely heuristic.
+- Things that satisfy the likely heuristic.
+
+Note: When sampling based on a search, you can set num_examples to a much higher \
+number than the num_results used for the search tool. For example, if you wanted to \
+sample in 1000 examples, you might do a search with num_results=5 and then sample 1000 \
+examples from that search. This is encouraged so that you don't need to pull literally \
+every example into context.
+
+# Step 5: Target specific language
+
+You should perform additional queries to sample in any specific language that is discussed \
+in the label instructions. How many examples to add will depend on dataset size, but feel free to \
+grow the dataset size by ~20% with the examples you pull in. Be thorough to capture any edge cases \
+whose representation needs to be boosted.
+
+# Step 6: Target unknown language
+
+Try to come up with queries that exclude as much of the easily targetable language as possible. \
+The goal here is to sample in even more underrepresented language that will not be easily \
+targeted by the other queries. These should be examples that live between the minimal and \
+likely heuristics, but which are even more narrowly targeted.
+"""
+
+ANNOTATE = """
+# Step 1: Annotate remaining examples
+
+Annotate all unannotated training examples for this label. Search for \
+unannotated documents (use annotation='none', query=None) and classify each one as 'yes', \
+'no', or 'skip' based on the label instructions.
+
+Work through the unannotated examples in batches of up to 100 examples. For each batch, \
+read the documents carefully and apply the label criteria consistently. Use 'yes' for \
+clear matches, 'no' for clear non-matches, and 'skip' only for documents that \
+are genuinely ambiguous or where you cannot make a confident determination.
+
+If the label instructions are unclear or you encounter edge cases not covered \
+by the guidelines, use your tool to ask the user for clarification so that you \
+can update the instructions before proceeding.
+
+After each batch, call the ClearToolHistory tool to make some extra room for context.
+
+# Step 2: Complete after all annotations or 3 batches
+
+You should continue annotating until all unannotated examples have been annotated or \
+until you've processed 3 batches. No matter what, do not end your turn without calling \
+the CompleteTask tool. Failing to do so will halt the process and require user intervention, \
+which we want to avoid (unless you are asking the user for feedback / clarification).
+
+> Note: If asked to compact your memory, you don't need to store specific annotations or \
+document IDs.
+""" + + +prompt_registry = { + "project_understanding": { + "name": "Project Understanding", + "content": PROJECT_UNDERSTANDING.strip(), + }, + "label_understanding": { + "name": "Label Understanding", + "content": LABEL_UNDERSTANDING.strip(), + }, + "sampling_strategy": { + "name": "Sampling Strategy", + "content": SAMPLING_STRATEGY.strip(), + }, + "annotate": { + "name": "Annotate", + "content": ANNOTATE.strip(), + }, +} diff --git a/clx/app/search.py b/clx/app/search.py new file mode 100644 index 0000000..5865bc7 --- /dev/null +++ b/clx/app/search.py @@ -0,0 +1,300 @@ +from __future__ import annotations + +from typing import Literal + +from django.db.models import F as models_F +from django.db.models import OuterRef, Q, Subquery +from django.db.models.fields import Field as DjangoField +from django.db.models.lookups import Lookup +from postgres_copy import CopyManager, CopyQuerySet +from pydantic import BaseModel + +# ── Custom ILIKE lookup (uses trigram GIN index) ──────────── + + +class ILike(Lookup): + """Case-insensitive LIKE using ILIKE — supported by pg_trgm GIN index.""" + + lookup_name = "ilike" + + def as_sql(self, compiler, connection): + lhs, lhs_params = self.process_lhs(compiler, connection) + rhs, rhs_params = self.process_rhs(compiler, connection) + return f"{lhs} ILIKE {rhs}", [*lhs_params, *rhs_params] + + +DjangoField.register_lookup(ILike) + +# ── Query Schema ───────────────────────────────────────────── + + +class Contains(BaseModel): + type: Literal["contains"] = "contains" + value: str + + +class StartsWith(BaseModel): + type: Literal["startsWith"] = "startsWith" + value: str + + +class Not(BaseModel): + type: Literal["not"] = "not" + query: Query + + +class Or(BaseModel): + type: Literal["or"] = "or" + queries: list[Query] + + +class And(BaseModel): + type: Literal["and"] = "and" + queries: list[Query] + + +Query = Contains | StartsWith | Not | Or | And + +Not.model_rebuild() +Or.model_rebuild() +And.model_rebuild() + + +# ── Query String Parser ────────────────────────────────────── + + +def parse_query(text: str) -> dict: + """Parse a shorthand query string into a query dict. + + Syntax: + , = AND + | = OR + ~ = NOT + ^ = STARTSWITH + () = grouping + AND binds tighter than OR, so `A, B | C` = `A AND (B OR C)` + ...wait, the user said the opposite: "ands scoped outside ors" + meaning `A, B | C` = `A AND (B OR C)`. + Actually re-reading: "commas=AND ... ands as scoped outside of ors, + so A, B | C means A and (B or C)". + This means AND has *lower* precedence than OR. 
+ """ + tokens = _tokenize(text) + if not tokens: + return {"type": "contains", "value": ""} + parser = _Parser(tokens) + result = parser.parse_expr() + if parser.peek() is not None: + raise ValueError(f"Unexpected token: {parser.peek()}") + return result + + +def _tokenize(text: str) -> list[str]: + tokens = [] + i = 0 + while i < len(text): + c = text[i] + if c in ",|~^()": + tokens.append(c) + i += 1 + elif c.isspace(): + i += 1 + else: + j = i + while j < len(text) and text[j] not in ",|~^()": + j += 1 + tokens.append(text[i:j].strip()) + i = j + return tokens + + +class _Parser: + def __init__(self, tokens: list[str]): + self.tokens = tokens + self.pos = 0 + + def peek(self) -> str | None: + return self.tokens[self.pos] if self.pos < len(self.tokens) else None + + def consume(self, expected: str | None = None) -> str: + tok = self.peek() + if tok is None: + raise ValueError(f"Unexpected end of input (expected {expected})") + if expected and tok != expected: + raise ValueError(f"Expected '{expected}', got '{tok}'") + self.pos += 1 + return tok + + def parse_expr(self) -> dict: + return self.parse_and() + + def parse_and(self) -> dict: + parts = [self.parse_or()] + while self.peek() == ",": + self.consume(",") + parts.append(self.parse_or()) + return ( + parts[0] if len(parts) == 1 else {"type": "and", "queries": parts} + ) + + def parse_or(self) -> dict: + parts = [self.parse_unary()] + while self.peek() == "|": + self.consume("|") + parts.append(self.parse_unary()) + return ( + parts[0] if len(parts) == 1 else {"type": "or", "queries": parts} + ) + + def parse_unary(self) -> dict: + if self.peek() == "~": + self.consume("~") + return {"type": "not", "query": self.parse_unary()} + return self.parse_primary() + + def parse_primary(self) -> dict: + if self.peek() == "(": + self.consume("(") + expr = self.parse_expr() + self.consume(")") + return expr + if self.peek() == "^": + self.consume("^") + return {"type": "startsWith", "value": self.consume()} + return {"type": "contains", "value": self.consume()} + + +# ── Django Q Builder ───────────────────────────────────────── + + +def build_q(query: dict) -> Q: + """Convert a query dict into a Django Q object. + + Uses ILIKE for contains (trigram GIN index compatible). + startsWith uses text_prefix for better index utilization. 
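+
+    Example (illustrative):
+        {"type": "startsWith", "value": "motion"}
+            -> Q(text_prefix__ilike="motion%")
+        {"type": "not", "query": {"type": "contains", "value": "denied"}}
+            -> ~Q(text__ilike="%denied%")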
+ """ + match query["type"]: + case "contains": + return Q(text__ilike=f"%{query['value']}%") + case "startsWith": + return Q(text_prefix__ilike=f"{query['value']}%") + case "not": + return ~build_q(query["query"]) + case "or": + result = Q() + for sub in query["queries"]: + result |= build_q(sub) + return result + case "and": + result = Q() + for sub in query["queries"]: + result &= build_q(sub) + return result + case _: + raise ValueError(f"Unknown query type: {query['type']}") + + +# ── Custom QuerySet & Manager ──────────────────────────────── + + +class SearchQuerySet(CopyQuerySet): + def text_query(self, query: dict) -> SearchQuerySet: + """Filter using a query dict (matching the Query schema).""" + return self.filter(build_q(query)) + + def query_string(self, qs: str) -> SearchQuerySet: + """Filter using the shorthand query string syntax.""" + return self.text_query(parse_query(qs)) + + def training_examples(self, label_id: str) -> SearchQuerySet: + """Filter to documents in a label's training set.""" + return self.filter(label_documents__label_id=label_id) + + def with_annotation( + self, label_id: str, source: str = "agent" + ) -> SearchQuerySet: + """Annotate each document with its classification value for label+source. + + Adds an `annotation_value` field (str or None) to each row. + Uses a single subquery — no N+1. + """ + from clx.app.models import ClassificationAnnotation + + return self.annotate( + annotation_value=Subquery( + ClassificationAnnotation.objects.filter( + label_document__document=OuterRef("pk"), + label_document__label_id=label_id, + source=source, + ).values("value")[:1] + ) + ) + + def filter_annotation( + self, label_id: str, value: str, source: str = "agent" + ) -> SearchQuerySet: + """Filter documents by annotation value for a label+source. + + value: 'yes', 'no', 'skip', 'none' (unannotated), or 'any' (has annotation). + """ + qs = self.training_examples(label_id).with_annotation(label_id, source) + if value == "none": + return qs.filter(annotation_value__isnull=True) + if value == "any": + return qs.filter(annotation_value__isnull=False) + return qs.filter(annotation_value=value) + + def with_prediction(self, label_id: str) -> SearchQuerySet: + """Annotate each document with its prediction value and confidence. + + Adds `prediction_value` (str or None) and `prediction_confidence` + (float or None) fields to each row. + """ + from clx.app.models import LabelDocument + + return self.annotate( + prediction_value=Subquery( + LabelDocument.objects.filter( + document=OuterRef("pk"), + label_id=label_id, + ).values("prediction")[:1] + ), + prediction_confidence_value=Subquery( + LabelDocument.objects.filter( + document=OuterRef("pk"), + label_id=label_id, + ).values("prediction_confidence")[:1] + ), + ) + + def filter_prediction(self, label_id: str, value: str) -> SearchQuerySet: + """Filter documents by prediction value for a label. + + value: 'yes', 'no', 'any', or 'disagree' (prediction != agent annotation). 
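+
+        Example (sketch; assumes the queryset comes from a model that uses
+        SearchManager):
+            qs.filter_prediction(label_id, "disagree")
+            # -> training examples whose stored prediction conflicts with
+            #    the agent annotation for that label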
+ """ + from clx.app.models import LabelDocument + + qs = self.training_examples(label_id) + qs = qs.annotate( + prediction_value=Subquery( + LabelDocument.objects.filter( + document=OuterRef("pk"), + label_id=label_id, + ).values("prediction")[:1] + ) + ) + if value == "disagree": + qs = qs.with_annotation(label_id, "agent") + return qs.filter( + prediction_value__isnull=False, + annotation_value__isnull=False, + ).exclude(prediction_value=models_F("annotation_value")) + if value == "any": + return qs.exclude(prediction_value="").filter( + prediction_value__isnull=False + ) + return qs.filter(prediction_value=value) + + +class SearchManager(CopyManager.from_queryset(SearchQuerySet)): + pass diff --git a/clx/app/search_utils.py b/clx/app/search_utils.py deleted file mode 100644 index 712ba43..0000000 --- a/clx/app/search_utils.py +++ /dev/null @@ -1,639 +0,0 @@ -import csv -import random -from io import StringIO - -import pandas as pd -import simplejson as json -from django.apps import apps -from django.contrib.postgres.fields import ArrayField -from django.contrib.postgres.indexes import GinIndex -from django.db import connection, models, transaction -from django.db.models import Q -from django.utils import timezone -from pgvector.django import ( - CosineDistance, - HalfVectorField, - HnswIndex, -) -from postgres_copy import CopyManager, CopyQuerySet -from pydantic import BaseModel as PydanticModel - -from clx import generate_hash -from clx.llm import batch_embed - - -# Pydantic Models -class TagParams(PydanticModel): - any: list[int] = [] - all: list[int] = [] - not_any: list[int] = [] - not_all: list[int] = [] - - -class SearchParams(PydanticModel): - heuristic_bucket: str | None = None - trainset_split: str | None = None - predictor_value: str | None = None - annotation_value: str | None = None - review_disagreements: bool | None = None - tags: TagParams = TagParams() - querystring: str | None = None - - -class SearchQuery(PydanticModel): - active_label_id: int | None = None - params: SearchParams = SearchParams() - sort: list[str] = ["shuffle_sort", "id"] - semantic_sort: str | list[float] | None = None - page: int = 1 - page_size: int = 100 - count: bool = False - - -# QuerySets -class SearchQuerySet(CopyQuerySet): - """QuerySet for search queries.""" - - def batch_df(self, *columns, batch_size=1000): - last_id = None - self = self.order_by("id") - while 1: - if last_id is not None: - self = self.filter(id__gt=last_id) - data = pd.DataFrame(self.values(*columns)[:batch_size]) - if len(data) == 0: - break - yield data - last_id = data["id"].max() - - def querystring(self, value): - """Apply a querystring to the query. - - For querystrings we will do an exact substring match on terms / phrases. - Commas will be treated as AND operators. - Bars will be treated as OR operators. - Tildes will be treated as NOT operators. - Carets will be treated as startswith operators. - We will always assume that ORs are nested in ANDs. 
- """ - if value is not None: - assert isinstance(value, str), "Querystring must be a string" - and_condition = None - for and_part in value.split(","): - or_condition = None - for or_part in and_part.split("|"): - or_part = or_part.strip() - negated = False - if or_part.startswith("~"): - or_part = or_part[1:].strip() - negated = True - if or_part.startswith("^"): - or_part = or_part[1:].strip() - condition = Q(text_prefix__istartswith=or_part) - else: - condition = Q(text__icontains=or_part.strip()) - if negated: - condition = ~condition - if or_condition is None: - or_condition = condition - else: - or_condition |= condition - if and_condition is None: - and_condition = or_condition - else: - and_condition &= or_condition - self = self.filter(and_condition) - return self - - def tags(self, **params): - params = TagParams(**params).model_dump() - if params.get("any"): - self = self.filter(example_tags__tags__overlap=params["any"]) - if params.get("all"): - self = self.filter(example_tags__tags__contains=params["all"]) - if params.get("not_any"): - self = self.exclude(example_tags__tags__overlap=params["not_any"]) - if params.get("not_all"): - self = self.exclude(example_tags__tags__contains=params["not_all"]) - return self - - def semantic_sort(self, value): - """Apply a semantic sort to the query.""" - if isinstance(value, str): - value = batch_embed([value], dimensions=96)[0] - assert isinstance(value, list), ( - "Semantic sort must be a string or list" - ) - assert len(value) == 96, "Semantic sort must be a list of 96 floats" - assert all(isinstance(v, float) for v in value), ( - "Semantic sort must be a list of floats" - ) - return self.annotate( - distance=CosineDistance("embedding", value) - ).order_by("distance") - - def search(self, **query): - """Search with params, pagination, and sorting.""" - project = self.model.get_project() - - # Prepare query - if query.get("params", {}).get("tags"): - query["params"]["tags"] = { - k: get_tag_ids(v, project.id) - for k, v in query["params"]["tags"].items() - if v - } - - # Validate query - query = SearchQuery(**query).model_dump() - self = self.annotate(tags=models.F("example_tags__tags")) - - active_label_id = query.get("active_label_id") - label = ( - project.labels.get(id=active_label_id) if active_label_id else None - ) - - # Apply heuristic bucket filter - heuristic_bucket = query["params"].get("heuristic_bucket") - if label is not None and heuristic_bucket: - if heuristic_bucket == "excluded": - self = label.excluded_query(self) - elif heuristic_bucket == "neutral": - self = label.neutral_query(self) - elif heuristic_bucket == "likely": - self = label.likely_query(self) - - # Apply trainset split filter - trainset_split = query["params"].get("trainset_split") - if label is not None and trainset_split: - if trainset_split == "train": - self = self.tags(any=[label.trainset_train_tag.id]) - elif trainset_split == "eval": - self = self.tags(any=[label.trainset_eval_tag.id]) - elif trainset_split == "both": - self = self.tags( - any=[ - label.trainset_train_tag.id, - label.trainset_eval_tag.id, - ] - ) - - # Apply predictor value filter - predictor_value = query["params"].get("predictor_value") - if label is not None and predictor_value: - self = self.tags( - any=[label.trainset_train_tag.id, label.trainset_eval_tag.id] - ) - if predictor_value == "true": - self = self.tags(any=[label.trainset_pred_tag.id]) - elif predictor_value == "false": - self = self.tags(not_any=[label.trainset_pred_tag.id]) - - # Apply manual annotation filter - 
annotation_value = query["params"].get("annotation_value") - if label is not None and annotation_value: - if annotation_value == "true": - self = self.tags(any=[label.anno_true_tag.id]) - elif annotation_value == "false": - self = self.tags(any=[label.anno_false_tag.id]) - elif annotation_value == "flag": - self = self.tags(any=[label.anno_flag_tag.id]) - elif annotation_value == "any": - self = self.tags( - any=[ - label.anno_true_tag.id, - label.anno_false_tag.id, - label.anno_flag_tag.id, - ] - ) - elif annotation_value == "none": - self = self.tags( - not_any=[ - label.anno_true_tag.id, - label.anno_false_tag.id, - label.anno_flag_tag.id, - ] - ) - - # Apply disagreements review filter - review_disagreements = query["params"].get("review_disagreements") - if label is not None and review_disagreements: - tag_ids = [label.trainset_pred_tag.id] - config_names = list( - label.fintunes.values_list("config_name", flat=True) - ) - for config_name in config_names: - tag_ids.append(label.get_trainset_finetune_tag(config_name).id) - if len(tag_ids) <= 1: - self = self.none() - else: - q_disagree = Q() - for i in tag_ids: - for j in tag_ids: - if i == j: - continue - q_disagree |= Q(example_tags__tags__contains=[i]) & ~Q( - example_tags__tags__contains=[j] - ) - self = self.filter(q_disagree) - - # Apply param filters - params = query["params"] - self = self.tags(**params.get("tags", {})) - self = self.querystring(params.get("querystring")) - - # Return count if requested - if query.get("count"): - return {"total": self.count()} - - # Apply sorting - if query.get("semantic_sort"): - self = self.semantic_sort(query["semantic_sort"]) - else: - self = self.order_by(*query["sort"]) - - # Select columns - cols = ["id", "text_hash", "text", "tags"] - self = self.values(*cols) - - # Apply pagination - self = self.page(query["page"], size=query["page_size"]) - data = list(self) - if label is not None and len(data): - data = pd.DataFrame(data) - trainset_examples = label.trainset_examples.filter( - text_hash__in=data["text_hash"].tolist() - ) - trainset_examples = trainset_examples.values( - "text_hash", "split", "pred", "reason" - ) - trainset_examples = pd.DataFrame(trainset_examples) - if len(trainset_examples): - trainset_examples = trainset_examples.drop_duplicates( - subset="text_hash" - ) - data = data.merge( - trainset_examples, on="text_hash", how="left" - ) - data = data.to_dict(orient="records") - data = json.loads(json.dumps(data, ignore_nan=True)) - return {"data": data} - - def page(self, page, size=100): - assert isinstance(page, int), "Page number must be an integer" - assert page > 0, "Page number must be greater than 0" - assert size > 0, "Page size must be greater than 0" - assert size <= 1000, "Page size must be less than 1000" - return self[size * (page - 1) : size * page] - - -# Queryset Managers -class SearchManager(CopyManager.from_queryset(SearchQuerySet)): - pass - - -# Abstract Models -class BaseModel(models.Model): - """Base model for all models""" - - id = models.BigAutoField(primary_key=True) - created_at = models.DateTimeField(auto_now_add=True) - updated_at = models.DateTimeField(auto_now=True) - - class Meta: - abstract = True - - -class SearchDocumentModelBase(models.base.ModelBase): - """Meta class for search document models.""" - - def __new__(cls, name, bases, attrs, **kwargs): - """Create a new search document model.""" - if "Meta" not in attrs: - project_id = attrs.get("project_id") - if project_id is None: - raise ValueError(f"{name} must define a project_id") - 
attrs["Meta"] = type( - "Meta", - (), - { - "db_table": f"project_{project_id}_doc", - "indexes": [ - models.Index( - fields=["shuffle_sort", "id"], - name=f"{project_id}_s_idx", - ), - models.Index( - fields=["text_prefix"], - name=f"{project_id}_pr_idx", - opclasses=["text_pattern_ops"], - ), - GinIndex( - fields=["text"], - name=f"{project_id}_trg_idx", - opclasses=["gin_trgm_ops"], - ), - HnswIndex( - fields=["embedding"], - name=f"{project_id}_hnsw_idx", - m=16, - ef_construction=64, - opclasses=["halfvec_cosine_ops"], - ), - ], - }, - ) - return super().__new__(cls, name, bases, attrs, **kwargs) - - -class SearchDocumentModel(BaseModel, metaclass=SearchDocumentModelBase): - """Search document model.""" - - project_id = None - finetune_configs = {} - main_finetune_config = None - - id = models.BigIntegerField(primary_key=True) - text = models.TextField() - text_prefix = models.CharField(max_length=50) - text_hash = models.CharField(max_length=255) - shuffle_sort = models.IntegerField() - embedding = HalfVectorField(dimensions=96) - - objects = SearchManager() - - def save(self, *args, **kwargs): - self.text_prefix = self.text[:50] - self.text_hash = generate_hash(self.text) - super().save(*args, **kwargs) - - @classmethod - def get_project(cls): - """Get the project for the search document.""" - return get_search_model_project(cls) - - @property - def project(self): - """Get the project for the search document.""" - return self.get_project() - - @classmethod - def create_tags_model(cls): - model_name = f"{cls.__name__}Tags" - - attrs = { - "__module__": cls.__module__, - "is_tag_model": True, - "project_id": cls.project_id, - "id": models.OneToOneField( - cls, - on_delete=models.CASCADE, - primary_key=True, - db_column="id", - related_name="example_tags", - ), - "tags": ArrayField( - models.BigIntegerField(), default=list, blank=True - ), - "objects": CopyManager(), - "get_project": classmethod(get_search_model_project), - "project": property(lambda self: self.get_project()), - "Meta": type( - "Meta", - (), - { - "db_table": f"project_{cls.project_id}_tags", - "indexes": [ - GinIndex( - fields=["tags"], - name=f"{cls.project_id}_t_gin", - ), - ], - }, - ), - } - - TagsModel = type(model_name, (models.Model,), attrs) - return TagsModel - - @classmethod - def guarantee_tags_rows(cls): - q = cls.objects.filter(example_tags__isnull=True) - if q.exists(): - for data in q.batch_df("id", batch_size=500000): - tags_model = cls.get_project().get_tags_model() - f = StringIO() - data.to_csv(f, index=False) - f.seek(0) - tags_model.objects.from_csv( - f, - static_mapping={"tags": "{}"}, - ignore_conflicts=True, - ) - - @classmethod - def bulk_replace_tag(cls, tag, ids): - """Bulk replace a tags for a table.""" - tag_id = get_tag_ids([tag], cls.project_id)[0] - tags_table = cls.get_project().get_tags_model()._meta.db_table - added = removed = 0 - - with transaction.atomic(), connection.cursor() as cur: - cur.execute( - "CREATE TEMP TABLE stage_tag_ids(example_id BIGINT) ON COMMIT DROP;" - ) - cur.execute("CREATE INDEX ON stage_tag_ids(example_id);") - - f = StringIO("".join(f"{i}\n" for i in ids)) - cur.copy_expert( - "COPY stage_tag_ids (example_id) FROM STDIN WITH (FORMAT CSV)", - f, - ) - - cur.execute( - f""" - UPDATE "{tags_table}" t - SET tags = array_cat(t.tags, ARRAY[%s]::bigint[]) - FROM stage_tag_ids s - WHERE t.id = s.example_id - AND NOT (t.tags @> ARRAY[%s]::bigint[]) - """, - [tag_id, tag_id], - ) - added = cur.rowcount - - cur.execute( - f""" - UPDATE "{tags_table}" t - SET tags = 
array_remove(t.tags, %s) - WHERE t.tags @> ARRAY[%s]::bigint[] - AND NOT EXISTS ( - SELECT 1 FROM stage_tag_ids s WHERE s.example_id = t.id - ) - """, - [tag_id, tag_id], - ) - removed = cur.rowcount - - return added, removed - - @classmethod - def bulk_update_column(cls, column, ids, values, id_column="id"): - """Bulk update column values by ID.""" - assert len(ids) == len(values), "ids and values must match in length" - - field = cls._meta.get_field(column) - field_type = get_pg_type(field) - id_type = get_pg_type(cls._meta.get_field(id_column)) - - table = cls._meta.db_table - - f = StringIO() - writer = csv.writer(f) - for k, v in zip(ids, values): - writer.writerow([k, "" if v is None else v]) - f.seek(0) - with transaction.atomic(), connection.cursor() as cur: - cur.execute( - f"CREATE TEMP TABLE stage_updates(id {id_type}, val text) ON COMMIT DROP;" - ) - cur.copy_expert( - "COPY stage_updates (id, val) FROM STDIN WITH (FORMAT CSV)", - f, - ) - cur.execute( - f""" - UPDATE "{table}" t - SET {column} = - CASE WHEN s.val = '' THEN NULL ELSE s.val::{field_type} END - FROM stage_updates s - WHERE t.{id_column} = s.id - """ - ) - updated = cur.rowcount - return updated - - @classmethod - def bulk_insert(cls, data, **kwargs): - """Bulk insert data into the model.""" - if "id" not in data.columns: - start_id = 1 - if cls.objects.exists(): - start_id = cls.objects.order_by("-id").first().id + 1 - data["id"] = range(start_id, start_id + len(data)) - data = data.dropna(subset=["text"]) - data["text"] = data["text"].str.strip() - data = data[data["text"].apply(len) > 0] - data["text_prefix"] = data["text"].apply(lambda x: x[:50]) - data["text_hash"] = data["text"].apply(generate_hash) - data["shuffle_sort"] = data["text_hash"].apply( - lambda x: random.randint(0, 100000000) - ) - embeddings = data.copy().drop_duplicates(subset=["text_hash"])[ - ["text_hash", "text"] - ] - embeddings = cls.get_project().load_or_add_embeddings(embeddings)[ - ["text_hash", "embedding"] - ] - data = data.merge(embeddings, on="text_hash", how="left") - f = StringIO() - data.to_csv(f, index=False) - f.seek(0) - cls.objects.from_csv( - f, - static_mapping={ - "created_at": timezone.now(), - "updated_at": timezone.now(), - }, - **kwargs, - ) - cls.guarantee_tags_rows() - - def set_annotation(self, label, value): - """Set annotation tag for this example for the given label.""" - if isinstance(value, bool): - value = "true" if value else "false" - assert value is None or value in ["true", "false", "flag"], ( - "value must be 'true', 'false', 'flag', True, False, or None" - ) - - tags = self.example_tags - tag_ids = { - "true": label.anno_true_tag.id, - "false": label.anno_false_tag.id, - "flag": label.anno_flag_tag.id, - } - for tag_id in tag_ids.values(): - if tag_id in tags.tags: - tags.tags.remove(tag_id) - if value in tag_ids: - tags.tags.append(tag_ids[value]) - tags.save() - label.update_trainset_pred_counts() - - class Meta: - abstract = True - indexes = [] - - -# Utils -def get_search_model_project(cls): - """Get the project for a search document or search tags model.""" - if cls.project_id is not None: - from .models import Project - - return Project.objects.get(id=cls.project_id) - - -def get_pg_type(field): - """Get the PostgreSQL type for a field.""" - if isinstance(field, (models.IntegerField | models.BigIntegerField)): - pg_type = "bigint" - elif isinstance(field, (models.TextField | models.CharField)): - pg_type = "text" - else: - raise NotImplementedError(f"Unsupported field type: {type(field)}") - 
return pg_type - - -def get_tag_ids(tags, project_id): - from clx.models import LabelTag - - if all(isinstance(tag, int) for tag in tags): - return tags - elif all(isinstance(tag, LabelTag) for tag in tags): - return [tag.id for tag in tags] - elif all(isinstance(tag, str) for tag in tags): - tags = LabelTag.objects.filter( - slug__in=tags, label__project__id=project_id - ) - return [tag.id for tag in tags] - else: - raise ValueError( - "tags must be same type, either int, LabelTag, or tag slug string" - ) - - -def init_search_models(**kwargs): - """Create a project for each SearchDocumentModel. - - Then register the associated tags model. - """ - from .models import Project - - for model in apps.get_models(): - if issubclass(model, SearchDocumentModel): - project, created = Project.objects.get_or_create( - id=model.project_id, - model_name=model.__name__, - ) - if created: - project.name = model.__name__ - project.save() - - for model in apps.get_models(): - if hasattr(model, "is_tag_model") and model.is_tag_model: - project = Project.objects.get(id=model.project_id) - if project.tags_model_name != model.__name__: - project.tags_model_name = model.__name__ - project.save() diff --git a/clx/app/templates/base.html b/clx/app/templates/base.html index 4f0d7b0..cca3744 100644 --- a/clx/app/templates/base.html +++ b/clx/app/templates/base.html @@ -3,16 +3,105 @@
-