-
Notifications
You must be signed in to change notification settings - Fork 0
Expand file tree
/
Copy pathlabel.py
More file actions
111 lines (101 loc) · 3.6 KB
/
label.py
File metadata and controls
111 lines (101 loc) · 3.6 KB
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
import os
import warnings
import argparse
from glob import glob
from concurrent.futures import ThreadPoolExecutor
#
import numpy as np
import pandas as pd
from tqdm import tqdm
from PIL import Image
import face_recognition
from langchain_ollama import ChatOllama
from langchain_core.prompts import ChatPromptTemplate
#
from utils import (
read_config,
gather_samples,
image_to_base64,
create_output_model,
construct_system_message,
)
from tracker import Tracker
#
# ---- command-line interface -------------------------------------------------
cli = argparse.ArgumentParser(description='Start the labeling process.')
cli.add_argument('data_dir', type=str, help='Path to data directory')
#
cli.add_argument('-c', '--config', type=str, default="config.yaml",
                 help='Path to config file.')
cli.add_argument('-o', '--output_dir', type=str, default="out",
                 help='Path to the output directory.')
cli.add_argument('-w', '--workers', type=int, default=4,
                 help='The number of workers for processing the samples')
cli.add_argument('-v', '--verbose', action='store_true',
                 help='Wether to print out info')
args = cli.parse_args()
#
# ---- shared module-level state (read by process_sample and __main__) --------
os.makedirs(args.output_dir, exist_ok=True)
output_filepath = os.path.join(args.output_dir, "annotation.csv")
config = read_config(args.config)
# Tracker persists per-sample progress so interrupted runs can resume.
tracker = Tracker(
    output_dir=args.output_dir,
    batch_size=args.workers,
)
# LLM client plus the structured-output schema and system prompt derived
# from the accessories listed in the config file.
llm = ChatOllama(model=config['configurations']['model'])
OutputModel = create_output_model(config['accessories'])
SYSTEM_MESSAGE = construct_system_message(config['accessories'])
#
def process_sample(sample_filepath: str) -> None:
    """Annotate one image file and append the results to the output CSV.

    Detects face regions (or uses the whole image when face detection is
    disabled in the config), sends each cropped region to the LLM for a
    structured description, and appends one CSV row per region to
    ``output_filepath``.

    Args:
        sample_filepath: Path to the image file to process.
    """
    # Load eagerly inside a context manager so the file handle is closed
    # promptly instead of lingering until garbage collection.
    with Image.open(sample_filepath) as pil_image:
        width, height = pil_image.size
        image = np.array(pil_image)

    if config['configurations']['detect_faces']:
        # BUGFIX: face_recognition.face_locations expects a numpy array;
        # the original passed the PIL Image object, which raises at runtime.
        bboxes = face_recognition.face_locations(image)
    else:
        # face_recognition box order is (top, right, bottom, left);
        # this single box covers the whole image.
        bboxes = [(0, width, height, 0)]

    annotations = []
    for person_id, bbox in enumerate(bboxes):
        ymin, xmax, ymax, xmin = bbox  # (top, right, bottom, left)
        face_image = Image.fromarray(image[ymin:ymax, xmin:xmax])
        image_base64 = image_to_base64(face_image)
        prompt = ChatPromptTemplate.from_messages([
            ("system", SYSTEM_MESSAGE),
            ("human", [
                {"type": "text", "text": "Describe the following image:"},
                {"type": "image_url", "image_url": f"data:image/jpeg;base64,{image_base64}"},
            ]),
        ])
        chain = prompt | llm.with_structured_output(OutputModel)
        response = chain.invoke({})
        annotations.append({
            "filename": sample_filepath,
            "person_id": person_id,
            "xmin": xmin,
            "ymin": ymin,
            "xmax": xmax,
            "ymax": ymax,
            **response.model_dump(),
        })

    # NOTE(review): this function runs concurrently in a ThreadPoolExecutor;
    # the append + header-existence check below is not atomic, so concurrent
    # workers can interleave rows or duplicate the header line. Consider
    # serializing writes with a lock or writing per-worker files.
    pd.DataFrame(annotations).to_csv(
        output_filepath,
        mode="a",
        header=not os.path.exists(output_filepath),
        index=False,
    )
if __name__ == "__main__":
    # Collect every candidate image under the data directory.
    all_samples = gather_samples(
        data_dir=args.data_dir,
        extensions=config['configurations']['image_extensions']
    )
    tracker.add_samples(all_samples)
    num_total = len(all_samples)
    num_pending = tracker.pending_count()
    # Start the bar at the already-done count so resumed runs show progress.
    # IMPROVED: one thread pool for the whole run instead of creating and
    # tearing one down per batch (the original also bound the results list
    # to an unused variable).
    with tqdm(total=num_total, initial=(num_total - num_pending),
              desc="Processing items") as pbar, \
         ThreadPoolExecutor(max_workers=args.workers) as pool:
        while tracker.pending_count() > 0:
            sample_paths = tracker.get_batch()
            # Fully consume the map iterator so worker exceptions propagate
            # here before any sample in the batch is marked done.
            for _ in pool.map(process_sample, sample_paths):
                pass
            for sample in sample_paths:
                tracker.mark_done(sample)
                pbar.update(1)