-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathapp.py
More file actions
217 lines (186 loc) · 7.63 KB
/
app.py
File metadata and controls
217 lines (186 loc) · 7.63 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
212
213
214
215
216
217
import os, io, re
from typing import Optional
from fastapi import FastAPI, HTTPException
from pydantic import BaseModel
from PIL import Image, ImageFile
import boto3
# Let Pillow decode image files that are truncated instead of rejecting them.
ImageFile.LOAD_TRUNCATED_IMAGES = True

# Runtime configuration, overridable through environment variables.
AWS_REGION = os.getenv("AWS_REGION", "us-east-1")
VLM_MODEL_ID = os.getenv("VLM_MODEL_ID", "Qwen/Qwen2.5-VL-7B-Instruct")
_raw_8bit_flag = os.getenv("VLM_LOAD_8BIT", "false")
# Case-insensitive truthy spellings; anything else means "load in full precision".
VLM_LOAD_8BIT = _raw_8bit_flag.lower() in ("1", "true", "yes")
# Simplified configuration - Qwen2.5-VL only
# ---------- S3 Loader ----------
# Module-wide S3 client used by s3_image().
s3 = boto3.client("s3", region_name=AWS_REGION)
def s3_image(s3_uri: str) -> Image.Image:
    """
    Download an image from S3 and decode it as an RGB PIL image.

    Args:
        s3_uri: Object location in the form ``s3://<bucket>/<key>``.

    Returns:
        The decoded image converted to RGB mode.

    Raises:
        HTTPException: 400 if the URI is malformed (wrong scheme, missing
            bucket or key), or if the object cannot be fetched or decoded.
    """
    if not s3_uri.startswith("s3://"):
        raise HTTPException(400, "s3_uri must start with s3://")
    # partition() never raises, so "s3://bucket" yields an empty key that we
    # can reject with a clear message instead of an unpacking error.
    bucket, _, key = s3_uri[len("s3://"):].partition("/")
    if not bucket or not key:
        raise HTTPException(400, "s3_uri must have the form s3://<bucket>/<key>")
    try:
        obj = s3.get_object(Bucket=bucket, Key=key)
        body = obj["Body"]
        try:
            data = body.read()
        finally:
            # Release the underlying HTTP connection even if read() fails.
            body.close()
        # convert() returns a new, fully-loaded image, so the source can be closed.
        with Image.open(io.BytesIO(data)) as img:
            return img.convert("RGB")
    except Exception as e:
        raise HTTPException(400, f"Failed to fetch/decode image: {e}") from e
# ---------- Qwen2.5-VL Captioner ----------
class QwenCaptioner:
    """
    Caption images with a Qwen2.5-VL instruction-tuned vision-language model.

    Bundles a ``transformers`` processor/model pair with a fixed captioning
    prompt and a deterministic beam-search decoding recipe tuned for detailed
    but concise captions.
    """

    # Professional OpenAI-style captioning prompt, sent as the user text turn.
    PROMPT = (
        "Audience: graphic designers, photographers, creative directors using text search. "
        "Task: Describe a single image in one stand-alone English sentence; ~30-50 words. "
        "Always include: Any clearly identifiable famous person, landmark, brand, artwork, product model (proper nouns, exact spelling). "
        "Stylistic or mood cues only if they are visually central. "
        "Always avoid: Guessing when uncertain; if identity is unclear, name the generic class rather than a brand. "
        "Camera metadata, hashtags, subjective opinions, filler, emojis."
    )

    def __init__(self, model_id: str, load_8bit: bool = False):
        """
        Load the processor and model for ``model_id`` and prepare them for inference.

        Args:
            model_id: Hugging Face Hub id or local path passed to ``from_pretrained``.
            load_8bit: If True and CUDA is available, try an 8-bit bitsandbytes
                load first and fall back to fp16 if it fails. Ignored on CPU.
        """
        import torch
        from transformers import AutoProcessor

        self.torch = torch
        self.device = "cuda" if torch.cuda.is_available() else "cpu"
        print(f"[boot] Loading Qwen2.5-VL model: {model_id}")
        print(f"[boot] Device: {self.device}")
        print(f"[boot] CUDA available: {torch.cuda.is_available()}")
        if torch.cuda.is_available():
            print(f"[boot] CUDA memory: {torch.cuda.get_device_properties(0).total_memory / 1e9:.1f}GB")
        # Load processor with fast processor disabled to avoid warnings
        self.processor = AutoProcessor.from_pretrained(
            model_id,
            trust_remote_code=True,
            use_fast=False,
        )
        print("[boot] Starting model loading...")
        self.model = self._load_model(model_id, load_8bit)
        self.model.eval()
        if self.device == "cuda":
            self._enable_cuda_fast_math()
        print("[boot] Qwen2.5-VL model loaded successfully")

    def _load_model(self, model_id: str, load_8bit: bool):
        """Load weights: 8-bit on CUDA when requested (fp16 fallback), else full model on self.device."""
        torch = self.torch
        from transformers import AutoModelForImageTextToText

        common_kwargs = {
            "trust_remote_code": True,
            "dtype": torch.float16 if self.device == "cuda" else torch.float32,
            "low_cpu_mem_usage": True,
        }
        if load_8bit and self.device == "cuda":
            # The try must wrap from_pretrained itself: that call is where a
            # failed quantized load is raised, so the fallback is actually reachable.
            try:
                from transformers import BitsAndBytesConfig

                print("[boot] Loading model in 8-bit with bitsandbytes")
                return AutoModelForImageTextToText.from_pretrained(
                    model_id,
                    quantization_config=BitsAndBytesConfig(load_in_8bit=True),
                    device_map="auto",  # accelerate places the weights; no .to() afterwards
                    **common_kwargs,
                )
            except Exception as e:
                print(f"[boot] 8-bit load failed: {e}; using fp16 instead")
                torch.cuda.empty_cache()  # drop any partially allocated weights
        model = AutoModelForImageTextToText.from_pretrained(model_id, **common_kwargs)
        print("[boot] Moving model to device...")
        return model.to(self.device)

    def _enable_cuda_fast_math(self):
        """Enable TF32 matmul/cuDNN kernels and cuDNN autotuning (best effort)."""
        torch = self.torch
        try:
            torch.backends.cuda.matmul.allow_tf32 = True
            torch.backends.cudnn.allow_tf32 = True
            torch.backends.cudnn.benchmark = True
            print("[boot] CUDA optimizations enabled")
        except Exception as e:
            print(f"[boot] CUDA optimization warning: {e}")

    def caption(self, pil_img: Image.Image) -> str:
        """
        Generate a one-sentence caption for ``pil_img``.

        Decoding is deterministic beam search:
        - ``num_beams=5`` explores several candidate sentences
        - ``do_sample=False``: no sampling (and therefore no temperature)
        - ``length_penalty=1.1`` slightly favours longer beams
        - ``no_repeat_ngram_size=3`` prevents repetition loops

        Returns:
            The generated text after ``postprocess_caption`` cleanup.
        """
        torch = self.torch
        proc = self.processor
        # Qwen expects chat-style messages with an image
        messages = [
            {"role": "system", "content": "You are a helpful visual caption assistant."},
            {"role": "user", "content": [
                {"type": "text", "text": self.PROMPT},
                {"type": "image", "image": pil_img},
            ]},
        ]
        # Render the template to text, then let a single processor call
        # tokenize it together with the image. That call expands the image
        # placeholder to the number of image tokens matching pixel_values.
        prompt = proc.apply_chat_template(
            messages, tokenize=False, add_generation_prompt=True
        )
        batch = proc(text=[prompt], images=[pil_img], return_tensors="pt")

        # model.device / model.dtype stay correct under device_map and 8-bit
        # loading (where the first parameter may be int8 or on another device).
        device = self.model.device
        dtype = self.model.dtype
        inputs = {
            k: (v.to(device, dtype=dtype) if v.is_floating_point() else v.to(device))
            for k, v in batch.items()
        }
        gen_kwargs = dict(
            max_new_tokens=64,
            num_beams=5,
            do_sample=False,
            length_penalty=1.1,
            no_repeat_ngram_size=3,
        )
        with torch.no_grad():
            out = self.model.generate(**inputs, **gen_kwargs)
        # generate() returns prompt + continuation for this decoder-only model;
        # keep only the newly generated tokens so the prompt is not echoed.
        new_tokens = out[:, inputs["input_ids"].shape[1]:]
        text = proc.batch_decode(new_tokens, skip_special_tokens=True)[0]
        return postprocess_caption(text)
# Leading chat-role labels ("assistant", "assistant:") and stray quote marks
# that may precede the caption. The \b keeps real words like "Assistants" intact.
_CAPTION_PREFIX_RE = re.compile(r"^\s*(?:(?:assistant\b:?|[\"“”])\s*)*", flags=re.IGNORECASE)
_CAPTION_QUOTES = "\"“”"


def postprocess_caption(text: str, max_words: int = 50) -> str:
    """
    Normalize raw model output into a single tidy caption sentence.

    Steps: strip a leading "assistant"/"assistant:" label and opening quotes
    (plus the matching closing quotes), collapse whitespace, truncate to
    ``max_words`` words, ensure terminal punctuation, and capitalize the
    first letter.

    Args:
        text: Raw decoded model output.
        max_words: Maximum number of words to keep (default 50).

    Returns:
        The cleaned caption, or ``""`` if nothing remains after cleanup.

    Raises:
        ValueError: If ``max_words`` is less than 1.
    """
    if max_words < 1:
        raise ValueError("max_words must be >= 1")
    t = text.strip()
    # The pattern ends in '*', so match() always succeeds (possibly empty).
    prefix = _CAPTION_PREFIX_RE.match(t).group(0)
    t = t[len(prefix):]
    if any(q in prefix for q in _CAPTION_QUOTES):
        # The caption was wrapped in quotes: drop the closing quote as well so
        # the punctuation check below sees the real end of the sentence.
        t = t.rstrip().rstrip(_CAPTION_QUOTES)
    t = re.sub(r"\s+", " ", t).strip()
    if not t:
        return ""
    words = t.split()
    if len(words) > max_words:
        # Drop dangling clause punctuation; a kept '.', '!' or '?' is reused below.
        t = " ".join(words[:max_words]).rstrip(",;:")
    if not t.endswith((".", "!", "?")):
        t += "."
    if t[0].islower():
        t = t[0].upper() + t[1:]
    return t
# ---------- FastAPI ----------
app = FastAPI(title="Professional Image Captioner (Qwen2.5-VL-7B)")
print("[boot] init …")
# Free cached CUDA memory and collect garbage before the large model load.
import gc
import torch
if torch.cuda.is_available():
    torch.cuda.empty_cache()
gc.collect()
try:
    # Loaded once at import time; the request handlers below share this instance.
    CAPTIONER = QwenCaptioner(VLM_MODEL_ID, VLM_LOAD_8BIT)
    print("[boot] Qwen2.5-VL loaded successfully")
except Exception as e:
    print(f"[boot] Failed to load Qwen2.5-VL: {e}")
    print("[boot] This might be due to insufficient memory or disk space")
    # Bare raise re-raises the active exception with its original traceback.
    raise
print("[boot] ready.")
class CaptionRequest(BaseModel):
    # Request body for POST /caption.
    # Image location; s3_image() requires the "s3://" scheme and splits the rest into bucket/key.
    s3_uri: str
    detailed: Optional[bool] = True # kept for compatibility; prompt already emphasizes detail (value is not read by the /caption handler)
@app.get("/health")
def health():
    """
    Report service liveness and the serving configuration.

    ``device`` reflects ``torch.cuda.is_available()`` at call time, and
    ``8bit_quantization`` echoes the VLM_LOAD_8BIT setting.
    """
    import torch

    has_cuda = torch.cuda.is_available()
    status = {"ok": True, "backend": "qwen2.5-vl-7b-instruct"}
    status["device"] = "cuda" if has_cuda else "cpu"
    status["model"] = VLM_MODEL_ID
    status["8bit_quantization"] = VLM_LOAD_8BIT
    return status
@app.post("/caption")
def caption(req: CaptionRequest):
    """
    Caption the image stored at ``req.s3_uri``.

    Returns:
        dict with ``caption``, the echoed ``s3_uri`` and the ``model`` id.

    Raises:
        HTTPException: propagated unchanged from ``s3_image`` (400 for bad
            URIs or unreadable images); any other failure becomes a 500
            whose detail is the exception text.
    """
    try:
        img = s3_image(req.s3_uri)
        cap = CAPTIONER.caption(img)
    except HTTPException:
        raise
    except Exception as e:
        import traceback

        # FastAPI turns HTTPException into a response without logging the
        # original error, so record the traceback in the server logs.
        traceback.print_exc()
        raise HTTPException(500, str(e)) from e
    return {
        "caption": cap,
        "s3_uri": req.s3_uri,
        "model": VLM_MODEL_ID,
    }