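"""Inference wrapper for generative segmentation with a Qwen2-VL-based model.

The model first generates a free-form text response; for every <|seg|> token
it emits, the generation KV cache is sliced back to that position and a batch
of <|mask|> query tokens (one per vision patch) is run through the model to
classify each patch as inside or outside the segmentation mask.
"""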
import torch
from PIL import Image
from transformers import AutoProcessor, DynamicCache
from model.qwen_changes import get_rope_index, SegQwenVL
import os
import json


def find_image_patch_info(image_pad_id, input_ids: torch.Tensor):
    """
    Search backwards from the end of the sequence for a contiguous run of
    image_pad_id tokens and return its length.

    Parameters:
        image_pad_id (int): The ID of the image padding token.
        input_ids (torch.Tensor): The input tensor of token IDs.

    Returns:
        int: The number of consecutive image patch tokens.

    Raises:
        RuntimeError: If no image patches (<|image_pad|>) are found in input_ids.
    """
    input_ids_list = input_ids.squeeze().tolist()
    # Reverse the list so the search runs from the end to the beginning.
    reversed_input_ids_list = input_ids_list[::-1]
    try:
        # Find the first occurrence of image_pad_id in the reversed list.
        start_idx_rev = reversed_input_ids_list.index(image_pad_id)
        end_idx_rev = start_idx_rev
        # Extend over the consecutive run of image_pad_id tokens.
        while (end_idx_rev + 1 < len(reversed_input_ids_list)
               and reversed_input_ids_list[end_idx_rev + 1] == image_pad_id):
            end_idx_rev += 1
        num_patches = (end_idx_rev - start_idx_rev) + 1
        return num_patches
    except ValueError:
        raise RuntimeError("No image patches (<|image_pad|>) found in input_ids.")
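# Illustrative example (with a hypothetical pad ID of 99): for squeezed
# input_ids [5, 99, 99, 99, 7, 99, 99, 8], the backwards search finds the
# *last* contiguous run, so find_image_patch_info(99, ...) returns 2.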


class GenerativeSegmenter:
    def __init__(self, model_path: str, min_pixels, max_pixels, **kwargs):
        self.device = kwargs.get("device_map", "cuda" if torch.cuda.is_available() else "cpu")
        # --- Intelligent loading logic ---
        adapter_config_path = os.path.join(model_path, "adapter_config.json")
        if os.path.exists(adapter_config_path):
            print(f"Detected PEFT adapter configuration: {adapter_config_path}. "
                  f"Loading the base model first, then the adapter.")
            # Read the base model path from the adapter configuration.
            with open(adapter_config_path, 'r', encoding='utf-8') as f:
                adapter_config = json.load(f)
            base_model_path = adapter_config.get("base_model_name_or_path")
            if not base_model_path:
                # If adapter_config.json does not record the base model, fall
                # back to a manually specified name or path.
                base_model_path = "Qwen/Qwen2-VL-7B-Instruct"
                print(f"Warning: 'base_model_name_or_path' not found in the adapter "
                      f"configuration. Using default base model: '{base_model_path}'")
            # 1. Load the base model.
            print(f"Loading base model from '{base_model_path}'...")
            self.model = SegQwenVL.from_pretrained(
                base_model_path,
                torch_dtype="auto",
                trust_remote_code=True,
                # attn_implementation="flash_attention_2",
                **kwargs
            )
            self.processor = AutoProcessor.from_pretrained(
                base_model_path, trust_remote_code=True,
                min_pixels=min_pixels, max_pixels=max_pixels
            )
            self.tokenizer = self.processor.tokenizer
            self._add_special_tokens()
            # 2. Load the adapter.
            print(f"Loading adapter from '{model_path}'...")
            self.model.load_adapter(model_path)
        else:
            print(f"No PEFT adapter detected. Loading full model directly from '{model_path}'.")
            # Original direct loading path.
            self.model = SegQwenVL.from_pretrained(
                model_path,
                torch_dtype="auto",
                trust_remote_code=True,
                # attn_implementation="flash_attention_2",
                **kwargs
            )
            self.processor = AutoProcessor.from_pretrained(
                model_path, trust_remote_code=True,
                min_pixels=min_pixels, max_pixels=max_pixels
            )
            self.tokenizer = self.processor.tokenizer
            self._add_special_tokens()
        # --- Intelligent loading logic ends ---
        # Replace the rope-index computation on the wrapped language model
        # class with the customized implementation from model.qwen_changes.
        TargetClass = type(self.model.model)
        TargetClass.get_rope_index = get_rope_index

        # Cache the key special-token IDs.
        self.yes_token_id = self.tokenizer.convert_tokens_to_ids("<|yes|>")
        self.no_token_id = self.tokenizer.convert_tokens_to_ids("<|no|>")
        self.seg_token_id = self.tokenizer.convert_tokens_to_ids("<|seg|>")
        self.mask_token_id = self.tokenizer.convert_tokens_to_ids("<|mask|>")
        self.image_pad_id = self.tokenizer.convert_tokens_to_ids('<|image_pad|>')
        self.eos_token_id = self.tokenizer.eos_token_id
        self.model.mask_token_id = self.mask_token_id

    def _add_special_tokens(self):
        special_tokens = {'additional_special_tokens': ["<|seg|>", "<|mask|>", "<|yes|>", "<|no|>"]}
        num_added = self.tokenizer.add_special_tokens(special_tokens)
        if num_added > 0:
            print(f"Added {num_added} special tokens. Resizing model embedding layer...")
            self.model.resize_token_embeddings(len(self.tokenizer))
            # Sanity check: the embedding layer should now match the tokenizer size.
            print(f"Resized vocabulary size: {len(self.tokenizer)}, "
                  f"Model embedding layer size: {self.model.get_input_embeddings().weight.shape[0]}")
        if self.tokenizer.pad_token_id is None:
            self.tokenizer.pad_token_id = self.tokenizer.eos_token_id
    @torch.no_grad()
    def generate_with_segmentation(self, image: Image.Image, prompt: str):
        messages = [{"role": "user", "content": [{"image": image}, {"text": prompt}]}]
        text = self.processor.apply_chat_template(messages, tokenize=False, add_generation_prompt=True)
        inputs = self.processor(text=[text], images=[image], return_tensors="pt")
        merge_size = self.processor.image_processor.merge_size
        inputs = {k: v.to(self.device) for k, v in inputs.items()}
        prompt_len = inputs['input_ids'].shape[1]
        image_grid_thw = inputs.get('image_grid_thw')  # vision grid (t, h, w); Qwen2-VL-style processors emit this key
        attention_mask_raw = inputs['attention_mask']

        outputs = self.model.generate(
            **inputs,
            max_new_tokens=1024,
            use_cache=True,
            return_dict_in_generate=True,
            eos_token_id=self.eos_token_id,
            pad_token_id=self.tokenizer.pad_token_id
        )
        sequence = outputs.sequences[0]
        full_past_key_values = outputs.past_key_values
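        # With return_dict_in_generate=True and use_cache=True, generate()
        # returns the full KV cache, which the per-<|seg|> passes below reuse
        # (after slicing) instead of re-encoding the prefix from scratch.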

        # Find all <|seg|> token positions in the generated sequence.
        seg_indices = torch.where(sequence == self.seg_token_id)[0].tolist()
        all_segmentation_masks = []
        if not seg_indices:  # No segmentation was requested in the response.
            generated_ids = sequence[prompt_len:]
            response_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
            return None, response_text

        num_patches = find_image_patch_info(self.image_pad_id, inputs['input_ids'])
        # Run one segmentation forward pass per <|seg|> token.
        for i, idx in enumerate(seg_indices):
            sliced_len = idx + 1
            attention_mask = attention_mask_raw[:, :sliced_len]
            # 1. Convert the cache to the legacy tuple-of-tuples format.
            legacy_cache = full_past_key_values.to_legacy_cache()
            # 2. Slice each layer's key/value tensors up to the <|seg|> position.
            past_key_values_sliced = tuple(
                (
                    key_layer[:, :, :sliced_len, :],
                    value_layer[:, :, :sliced_len, :]
                )
                for key_layer, value_layer in legacy_cache
            )
            past_key_values_sliced = DynamicCache.from_legacy_cache(past_key_values_sliced)
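
            # Build the mask query: append one <|mask|> token per image patch
            # after the <|seg|> prefix, and extend the attention mask to cover
            # both the generated prefix and the new query tokens.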
            mask_query_ids = torch.full((1, num_patches), self.mask_token_id, dtype=torch.long, device=self.device)
            # Ones for every position not already covered by the prompt mask:
            # the generated tokens up to <|seg|> plus the <|mask|> queries.
            num_new_positions = num_patches + sliced_len - int(attention_mask[0].sum())
            mask_query_attention_mask = torch.ones((1, num_new_positions), dtype=torch.long, device=self.device)
            mask_query_attention_mask = torch.cat((attention_mask, mask_query_attention_mask), dim=1)
            mask_grid_thw = image_grid_thw[-1].clone().unsqueeze(0)
            mask_pre_ids = sequence.clone().unsqueeze(0)
            mask_ids = torch.cat([mask_pre_ids[0, :sliced_len], mask_query_ids[0]], dim=0).unsqueeze(0)

            seg_forward_outputs = self.model(
                input_ids=mask_ids,
                attention_mask=mask_query_attention_mask,
                image_grid_thw=image_grid_thw,
                pixel_values=inputs['pixel_values'],
                past_key_values=past_key_values_sliced,
                return_dict=True,
                do_classification=True
            )
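            # The classification head emits one binary logit per <|mask|> query
            # token; logit > 0 (sigmoid > 0.5) marks the patch as part of the mask.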
            mask_logits = seg_forward_outputs.bi_logits[:, -num_patches:]
            segmentation_preds = (mask_logits > 0).long().squeeze().cpu()
            # Reshape the flat patch predictions onto the merged vision grid.
            h_grid, w_grid = mask_grid_thw[0, 1:]
            h_grid, w_grid = int(h_grid / merge_size), int(w_grid / merge_size)
            segmentation_preds = segmentation_preds.view(h_grid, w_grid)
            all_segmentation_masks.append(segmentation_preds)

        generated_ids = sequence[prompt_len:]
        response_text = self.tokenizer.decode(generated_ids, skip_special_tokens=True)
        return all_segmentation_masks, response_text
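

if __name__ == "__main__":
    # Minimal usage sketch. The checkpoint path, image path, and pixel budgets
    # below are hypothetical placeholders, not values shipped with this file.
    segmenter = GenerativeSegmenter(
        model_path="path/to/checkpoint",   # hypothetical checkpoint directory
        min_pixels=256 * 28 * 28,          # assumed Qwen2-VL-style pixel budget
        max_pixels=1280 * 28 * 28,
    )
    img = Image.open("path/to/image.jpg").convert("RGB")  # hypothetical image
    masks, answer = segmenter.generate_with_segmentation(img, "Please segment the dog.")
    print(answer)
    if masks is not None:
        print(f"Predicted {len(masks)} mask(s); first grid shape: {tuple(masks[0].shape)}")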