-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtest_single_sample.py
More file actions
123 lines (98 loc) · 3.71 KB
/
test_single_sample.py
File metadata and controls
123 lines (98 loc) · 3.71 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
#!/usr/bin/env python3
"""
Test single sample to ensure identical generation between our framework and direct generation.
"""
import os
import sys
import json
import time
from pathlib import Path
# Add the src directory to the path
sys.path.append(str(Path(__file__).parent / "src"))
from models.generic_model import GenericModel
from adaptive.adaptive_cot import AdaptiveCoT
# Checkpoint loaded when no ``model_name`` is passed to ``test_single_sample``.
DEFAULT_MODEL_PATH = "/raid/LLM/llama3.1-8b-instruct"

# Stop sequences applied to raw direct-generation output. Intended to mirror the
# stop-sequence handling inside AdaptiveCoT (not visible in this file) — keep the
# two lists in sync, otherwise the "Identical" comparison below is meaningless.
DIRECT_STOP_SEQUENCES = ("Q:", "</s>", "<|im_end|>", "\n\nQ:", "://")


def _truncate_at_stop_sequences(text, stop_sequences=DIRECT_STOP_SEQUENCES):
    """Return *text* cut before the earliest stop sequence, whitespace-stripped.

    Keeping only the part before each sequence in turn is equivalent to cutting
    at the earliest occurrence of any of them. ``str.split`` returns the whole
    string when the separator is absent, so no membership pre-check is needed.
    """
    for stop_seq in stop_sequences:
        text = text.split(stop_seq, 1)[0]
    return text.strip()


def test_single_sample(model_name=DEFAULT_MODEL_PATH):
    """Compare AdaptiveCoT's single-branch output with direct generation on one problem.

    Loads ``model_name`` through ``GenericModel``, solves one fixed GSM8K
    problem with AdaptiveCoT configured for a single branch, no few-shot
    examples and temperature 0, then generates directly from a bare
    ``Q: <problem>`` / ``A:`` prompt with the SAME sampling settings (read from
    the framework config so the two runs cannot drift apart). Both outputs and
    whether they match exactly are printed, followed by what
    ``AdaptiveCoT._extract_single_answer`` returns for a few sample strings.

    Args:
        model_name: Path or identifier handed to ``GenericModel``. Defaults to
            ``DEFAULT_MODEL_PATH``; the default also keeps pytest from treating
            the argument as a fixture request.

    Returns:
        ``(result, generated)``: the dict returned by
        ``AdaptiveCoT.solve_problem`` and the stop-truncated, stripped
        direct-generation text.
    """
    print("🔬 Single Sample Test: Ensuring Identical Generation")
    print("=" * 60)

    # Load model
    print("🔧 Loading model...")
    model = GenericModel(
        model_name=model_name,
        config={}
    )
    model.load_model()
    print("✅ Model loaded successfully")

    # Test problem
    problem = "Janet's ducks lay 16 eggs per day. She eats three for breakfast every morning and bakes muffins for her friends every day with four. She sells the remainder at the farmers' market daily for $2 per fresh duck egg. How much in dollars does she make every day at the farmers' market?"
    print(f"\n📝 Problem: {problem}")
    print("=" * 60)

    # Test our framework (no few-shot, single branch)
    print("\n🔧 Testing Our Framework (single branch, no few-shot)")
    print("-" * 50)
    config = {
        "adaptive_branching": False,
        "min_branches": 1,
        "max_branches": 1,
        "default_branches": 1,
        "num_fewshot": 0,
        "temperature": 0.0,
        "top_p": 1.0,
        "max_tokens": 512,
    }
    adaptive_cot = AdaptiveCoT(model, config)
    # perf_counter is monotonic and high-resolution: the right clock for durations.
    start_time = time.perf_counter()
    result = adaptive_cot.solve_problem(problem, max_parallel_branches=1)
    our_time = time.perf_counter() - start_time
    framework_output = result['reasoning_paths'][0]
    print("\n📊 Our Framework Result:")
    print(f" Answer: {result['final_answer']}")
    print(f" Time: {our_time:.2f}s")
    print(f" Reasoning: {framework_output[:200]}...")

    # Test direct generation with the same sampling settings as the framework run
    print("\n🔧 Testing Direct Generation")
    print("-" * 50)
    prompt = f"Q: {problem}\nA:"
    start_time = time.perf_counter()
    generated = model.generate(
        prompt,
        max_tokens=config["max_tokens"],
        temperature=config["temperature"],
        top_p=config["top_p"],
        do_sample=False,
        num_return_sequences=1,
    )
    direct_time = time.perf_counter() - start_time
    # generate() may hand back either a single string or a list of sequences.
    if isinstance(generated, list):
        generated = generated[0]
    # Apply same stop sequence handling as our framework
    generated = _truncate_at_stop_sequences(generated)
    print("\n📊 Direct Generation Result:")
    print(f" Generated: {generated[:200]}...")
    print(f" Time: {direct_time:.2f}s")

    # Compare results
    print("\n🔍 Comparison:")
    print("-" * 50)
    print(f"Our Framework: {framework_output}")
    print(f"Direct Gen: {generated}")
    print(f"Identical: {framework_output == generated}")

    # Test answer extraction on strings that previously caused problems
    print("\n🔍 Answer Extraction Test:")
    print("-" * 50)
    test_cases = [
        "The best answer is $70,000. $70,000.00. $70,000.00.",
        "The best answer is 20. 20 cups of feed. 20 cups of feed.",
        "The answer is 18.",
        "#### 18",
        "Final answer: 540",
    ]
    for i, test_text in enumerate(test_cases, 1):
        extracted = adaptive_cot._extract_single_answer(test_text)
        print(f"Test {i}: '{test_text}' -> '{extracted}'")

    return result, generated
if __name__ == "__main__":
    # Script entry point only: importing this module (or collecting it with
    # pytest) must not load the model. The returned pair is kept for
    # interactive inspection, e.g. after `python -i`.
    framework_result, direct_output = test_single_sample()