forked from tekaratzas/RustGPT
-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy path: main.rs
More file actions
145 lines (119 loc) · 4.54 KB
/
main.rs
File metadata and controls
145 lines (119 loc) · 4.54 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
use std::io::Write;
use ::llm::{EMBEDDING_DIM, HIDDEN_DIM, MAX_SEQ_LEN};
use dataset_loader::{Dataset, DatasetType};
use crate::{
embeddings::Embeddings, llm::LLM, output_projection::OutputProjection,
transformer::TransformerBlock, vocab::Vocab,
};
mod adam;
mod dataset_loader;
mod embeddings;
mod feed_forward;
mod layer_norm;
mod llm;
mod output_projection;
mod self_attention;
mod transformer;
mod vocab;
/// Builds a small LLM from scratch: loads the datasets, derives a vocabulary,
/// assembles the network (embeddings -> 3 transformer blocks -> output
/// projection), runs pre-training and instruction tuning, and finally drops
/// into an interactive prompt loop until "exit" or EOF.
fn main() {
    // Mock input used to show model output before and after training.
    let string = String::from("User: How do mountains form?");

    // Load both the pre-training corpus and the chat (instruction) corpus;
    // both are consumed below for vocabulary building and training.
    let dataset = Dataset::new(
        String::from("data/pretraining_data.json"),
        String::from("data/chat_training_data.json"),
        DatasetType::JSON,
    );

    // Extract every unique word from both corpora to build the vocabulary.
    let mut vocab_set = std::collections::HashSet::new();
    Vocab::process_text_for_vocab(&dataset.pretraining_data, &mut vocab_set);
    Vocab::process_text_for_vocab(&dataset.chat_training_data, &mut vocab_set);

    let mut vocab_words: Vec<String> = vocab_set.into_iter().collect();
    vocab_words.sort(); // Sort so token ids are deterministic across runs.
    let vocab_words_refs: Vec<&str> = vocab_words.iter().map(|s: &String| s.as_str()).collect();
    let vocab = Vocab::new(vocab_words_refs);

    // Layer stack: token embeddings, three transformer blocks, and a final
    // projection from EMBEDDING_DIM to vocabulary-size logits.
    let transformer_block_1 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
    let transformer_block_2 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
    let transformer_block_3 = TransformerBlock::new(EMBEDDING_DIM, HIDDEN_DIM);
    let output_projection = OutputProjection::new(EMBEDDING_DIM, vocab.words.len());
    let embeddings = Embeddings::new(vocab.clone());

    let mut llm = LLM::new(
        vocab,
        vec![
            Box::new(embeddings),
            Box::new(transformer_block_1),
            Box::new(transformer_block_2),
            Box::new(transformer_block_3),
            Box::new(output_projection),
        ],
    );

    // Hyperparameters, bound once so the log lines and the train() calls
    // cannot drift apart.
    let pretrain_epochs = 100;
    let pretrain_lr = 0.0005;
    let tuning_epochs = 100;
    let tuning_lr = 0.0001; // Much lower learning rate for stability.

    println!("\n=== MODEL INFORMATION ===");
    println!("Network architecture: {}", llm.network_description());
    println!(
        "Model configuration -> max_seq_len: {}, embedding_dim: {}, hidden_dim: {}",
        MAX_SEQ_LEN, EMBEDDING_DIM, HIDDEN_DIM
    );
    println!("Total parameters: {}", llm.total_parameters());

    println!("\n=== BEFORE TRAINING ===");
    println!("Input: {}", string);
    println!("Output: {}", llm.predict(&string));

    println!("\n=== PRE-TRAINING MODEL ===");
    println!(
        "Pre-training on {} examples for {} epochs with learning rate {}",
        dataset.pretraining_data.len(),
        pretrain_epochs,
        pretrain_lr
    );
    let pretraining_examples: Vec<&str> = dataset
        .pretraining_data
        .iter()
        .map(|s| s.as_str())
        .collect();
    let chat_training_examples: Vec<&str> = dataset
        .chat_training_data
        .iter()
        .map(|s| s.as_str())
        .collect();
    llm.train(pretraining_examples, pretrain_epochs, pretrain_lr);

    println!("\n=== INSTRUCTION TUNING ===");
    println!(
        "Instruction tuning on {} examples for {} epochs with learning rate {}",
        dataset.chat_training_data.len(),
        tuning_epochs,
        tuning_lr
    );
    llm.train(chat_training_examples, tuning_epochs, tuning_lr);

    println!("\n=== AFTER TRAINING ===");
    println!("Input: {}", string);
    let result = llm.predict(&string);
    println!("Output: {}", result);
    println!("======================\n");

    // Interactive mode for user input.
    println!("\n--- Interactive Mode ---");
    println!("Type a prompt and press Enter to generate text.");
    println!("Type 'exit' to quit.");
    let mut input = String::new();
    loop {
        // Reuse one buffer across iterations instead of allocating per line.
        input.clear();

        print!("\nEnter prompt: ");
        std::io::stdout().flush().unwrap();

        // read_line returns Ok(0) on EOF (e.g. Ctrl-D or an exhausted pipe);
        // without this check the loop would spin forever re-printing the prompt.
        let bytes_read = std::io::stdin()
            .read_line(&mut input)
            .expect("Failed to read input");
        if bytes_read == 0 {
            println!("Exiting interactive mode.");
            break;
        }

        // Trim whitespace and check for the exit command.
        let trimmed_input = input.trim();
        if trimmed_input.eq_ignore_ascii_case("exit") {
            println!("Exiting interactive mode.");
            break;
        }
        // Skip blank lines rather than querying the model with an empty prompt.
        if trimmed_input.is_empty() {
            continue;
        }

        // Generate a prediction, prefixing the conversational "User:" marker
        // to match the chat-training data format.
        let formatted_input = format!("User: {}", trimmed_input);
        let prediction = llm.predict(&formatted_input);
        println!("Model output: {}", prediction);
    }
}