-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathtrain_vector_embedding.py
More file actions
130 lines (102 loc) · 4.67 KB
/
train_vector_embedding.py
File metadata and controls
130 lines (102 loc) · 4.67 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
import json
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import Dataset, DataLoader
from transformers import GPT2Model, GPT2TokenizerFast
import os
from concurrent.futures import ThreadPoolExecutor
from tqdm import tqdm
import numpy as np
# Select the GPU when one is available, otherwise fall back to the CPU.
device = torch.device("cuda" if torch.cuda.is_available() else "cpu")
print(f"Using device: {device}")

# Load the custom tokenizer.
tokenizer_dir = '.'  # directory that contains the tokenizer.json file
tokenizer = GPT2TokenizerFast.from_pretrained(tokenizer_dir)

# Load the GPT-2 model; only its token-embedding layer (wte) is used below.
try:
    gpt2_model = GPT2Model.from_pretrained('gpt2')
    embedding_layer = gpt2_model.wte.to(device)
except Exception as e:
    print(f"Error loading GPT2 model: {str(e)}")
    # `exit` is a helper injected by the `site` module for interactive use and
    # is not guaranteed to exist (e.g. `python -S`); raising SystemExit is the
    # portable way to stop the script with exit status 1.
    raise SystemExit(1) from e

# Initialise embeddings for the new tokens: tokens in the tokenizer vocabulary
# that are not among the tokens mapped from ids below the GPT-2 vocab size.
model_vocab = set(tokenizer.convert_ids_to_tokens(range(gpt2_model.config.vocab_size)))
tokenizer_vocab = set(tokenizer.get_vocab().keys())
new_tokens = tokenizer_vocab - model_vocab
print(f"Found {len(new_tokens)} new tokens that require embeddings.")
# One trainable row per new token, same width as the pretrained embedding.
new_embeddings = nn.Embedding(len(new_tokens), embedding_layer.embedding_dim).to(device)
# Dataset class definition
class CorpusDataset(Dataset):
    """Dataset of token-id sequences, one per ``.tex`` file in a directory.

    Every file whose name ends in ``.tex`` (case-insensitive) is read and
    passed to ``tokenizer.encode(..., truncation=True, max_length=512)``.
    Files that cannot be read or that encode to nothing are dropped, so
    ``self.data`` may be shorter than ``self.files``.

    Args:
        directory: Path of the directory to scan (not recursive).
        tokenizer: Object exposing ``encode(text, truncation=..., max_length=...)``
            that returns a sequence of integer token ids.
    """

    def __init__(self, directory, tokenizer):
        self.tokenizer = tokenizer
        self.files = [
            os.path.join(directory, f)
            for f in os.listdir(directory)
            if f.lower().endswith('.tex')
        ]
        print(f"Found {len(self.files)} TeX files in the directory.")

        # executor.map yields results in the same order as self.files.
        with ThreadPoolExecutor() as executor:
            results = list(tqdm(executor.map(self.process_file, self.files), total=len(self.files)))

        # Drop files that failed to load or produced no tokens.
        self.data = [item for item in results if item]
        print(f"Processed {len(self.data)} non-empty files.")

    @staticmethod
    def _read_text(file_path):
        """Return the file's text, decoding as UTF-8 and falling back to cp949."""
        try:
            with open(file_path, 'r', encoding='utf-8') as f:
                return f.read()
        except UnicodeDecodeError:
            print(f"Warning: Unable to read {file_path} with UTF-8 encoding. Trying with 'cp949'...")
            with open(file_path, 'r', encoding='cp949') as f:
                return f.read()

    def process_file(self, file_path):
        """Read and tokenize one file.

        Args:
            file_path: Path of the file to process.

        Returns:
            The encoded token ids, or ``None`` when the file cannot be
            read/encoded (the error is printed, not raised) or when the
            encoding is empty.
        """
        try:
            content = self._read_text(file_path)
            encoded = self.tokenizer.encode(content, truncation=True, max_length=512)
        except Exception as e:
            # Best effort: one bad file must not abort loading the corpus.
            print(f"Error processing file {file_path}: {str(e)}")
            return None
        return encoded if encoded else None

    def __len__(self):
        """Return the number of successfully processed, non-empty files."""
        return len(self.data)

    def __getitem__(self, idx):
        """Return ``(tokens, length)`` for item ``idx`` as a 1-D long tensor and its length."""
        tokens = torch.tensor(self.data[idx], dtype=torch.long)
        return tokens, len(tokens)
# Collate function that pads the sequences of a batch
def collate_fn(batch):
    """Pad a batch of variable-length token sequences into one tensor.

    Args:
        batch: Non-empty sequence of ``(tokens, length)`` pairs, where
            ``tokens`` is a 1-D tensor of token ids and ``length`` is its
            length.

    Returns:
        A tuple ``(padded_sequences, lengths)``: a long tensor with one row
        per sequence, right-padded with 0 up to the longest sequence in the
        batch, and a tensor of the original lengths. Both are moved to the
        module-level ``device``.

    Raises:
        ValueError: If ``batch`` is empty.
    """
    from torch.nn.utils.rnn import pad_sequence

    if not batch:
        raise ValueError("collate_fn received an empty batch")

    sequences, lengths = zip(*batch)
    # pad_sequence right-pads every sequence to the longest one in the batch.
    padded_sequences = pad_sequence(list(sequences), batch_first=True, padding_value=0)
    padded_sequences = padded_sequences.to(dtype=torch.long)
    return padded_sequences.to(device), torch.tensor(lengths).to(device)
# Build the data loader.
corpus_path = r'C:\Users\wjdrb\Downloads\drive-download-20240719T063242Z-001\temp_train\2301'
dataset = CorpusDataset(corpus_path, tokenizer)
if len(dataset) == 0:
    # Nothing to train on; stop instead of saving untrained embeddings.
    raise SystemExit(f"No non-empty TeX files found in {corpus_path}; nothing to train on.")
dataloader = DataLoader(dataset, batch_size=32, shuffle=True, collate_fn=collate_fn)

# Loss function and optimizer; only the new-token embeddings are optimized.
criterion = nn.MSELoss()
optimizer = optim.Adam(new_embeddings.parameters(), lr=0.0001)

# Training loop.
num_epochs = 100
for epoch in range(num_epochs):
    epoch_loss = 0.0
    num_batches = 0
    for batch, lengths in tqdm(dataloader, desc=f"Epoch {epoch+1}/{num_epochs}"):
        batch = batch.to(device)
        lengths = lengths.to(device)

        # Stack the pretrained rows and the new rows into one lookup table.
        # The pretrained weight is detached because the optimizer never
        # updates it, so no gradient needs to be computed or stored for it.
        combined_embeddings = torch.cat([embedding_layer.weight.detach(), new_embeddings.weight])

        # Look up the embedding of every token in the batch.
        input_embeddings = combined_embeddings[batch]

        # Pair each position with the position that follows it.
        current = input_embeddings[:, :-1]
        following = input_embeddings[:, 1:]

        # Keep only pairs in which both tokens are real: collate_fn pads short
        # sequences with token id 0, and those positions must not enter the loss.
        positions = torch.arange(batch.size(1) - 1, device=batch.device)
        valid = positions.unsqueeze(0) < (lengths.unsqueeze(1) - 1)
        if not valid.any():
            # Every sequence in this batch has a single token; nothing to compare.
            continue

        # Compute the loss and back-propagate.
        loss = criterion(current[valid], following[valid])
        optimizer.zero_grad()
        loss.backward()
        optimizer.step()

        epoch_loss += loss.item()
        num_batches += 1

    # Report the mean loss over the epoch rather than only the last batch.
    if num_batches:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: {epoch_loss / num_batches}")
    else:
        print(f"Epoch {epoch+1}/{num_epochs}, Loss: n/a (no trainable batches)")

# Save the trained embeddings.
torch.save(new_embeddings.state_dict(), 'new_embeddings.pth')
print("Training completed. New embeddings saved to 'new_embeddings.pth'.")