Skip to content
This repository was archived by the owner on Dec 21, 2017. It is now read-only.

Commit dbb7994

Browse files
authored
Merge pull request #101 from Andrew62/master
Interactive DenseCap
2 parents d3e1564 + cf62d6a commit dbb7994

2 files changed

Lines changed: 767 additions & 0 deletions

File tree

attalos/imgtxt_algorithms/densecap/Interactive DenseCap.ipynb

Lines changed: 510 additions & 0 deletions
Large diffs are not rendered by default.
Lines changed: 257 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,257 @@
require 'torch'
require 'nn'
require 'image'

require 'densecap.DenseCapModel'
local utils = require 'densecap.utils'
local box_utils = require 'densecap.box_utils'
local vis_utils = require 'densecap.vis_utils'

--[[
Run a trained DenseCap model on images.

The inputs can be any one of:
- a single image: use the flag '-input_image' to give path
- a directory with images: use flag '-input_dir' to give dir path
- MSCOCO split: use flag '-input_split' to identify the split (train|val|test)

The output can be controlled with:
- max_images: maximum number of images to process. Set to -1 to process all
- output_dir: use this flag to identify directory to write outputs to
- output_vis: set to 1 to output images/json to the vis directory for nice viewing in JS/HTML
--]]

local cmd = torch.CmdLine()

-- Model options: which checkpoint to load and how to run detection
cmd:option('-checkpoint',
  'data/models/densecap/densecap-pretrained-vgg16.t7')
cmd:option('-image_size', 720)
cmd:option('-rpn_nms_thresh', 0.7)
cmd:option('-final_nms_thresh', 0.3)
cmd:option('-num_proposals', 1000)

-- Input settings: exactly one of these input sources should be given
cmd:option('-input_image', '',
  'A path to a single specific image to caption')
cmd:option('-input_dir', '', 'A path to a directory with images to caption')
cmd:option('-input_split', '',
  'A VisualGenome split identifier to process (train|val|test)')
cmd:option('-input_batch_dir', '',
  'A path to a directory of text files. Each line in each file should contain a path to an image')

cmd:option('-save_processed_img', 0,
  'if 1 then densecap will write the tagged image')

-- Only used when input_split is given
cmd:option('-splits_json', 'info/densecap_splits.json')
cmd:option('-vg_img_root_dir', '', 'root directory for vg images')

-- Output settings
cmd:option('-max_images', 100, 'max number of images to process')
cmd:option('-output_dir', '')
-- these settings are only used if output_dir is not empty
cmd:option('-num_to_draw', 10, 'max number of predictions per image')
cmd:option('-text_size', 2, '2 looks best I think')
cmd:option('-box_width', 2, 'width of rendered box')
cmd:option('-output_vis', 1,
  'if 1 then writes files needed for pretty vis into vis/ ')
cmd:option('-output_vis_dir', 'vis/data')

-- Misc: GPU selection and cudnn toggle
cmd:option('-gpu', 0)
cmd:option('-use_cudnn', 1)
local opt = cmd:parse(arg)
--- Load one image from disk, preprocess it for VGG-16, and run DenseCap on it.
-- @param model    DenseCap model already cast to dtype and in evaluate mode
-- @param img_path path to the image file
-- @param opt      parsed command-line options (uses opt.image_size)
-- @param dtype    tensor type string the model expects
-- @return table with fields: img (scaled float RGB image), boxes (xywh),
--         scores, captions
function run_image(model, img_path, opt, dtype)
  -- Load as 3-channel RGB and resize to the configured size
  local rgb = image.load(img_path, 3)
  rgb = image.scale(rgb, opt.image_size):float()
  local height, width = rgb:size(2), rgb:size(3)

  -- VGG-style preprocessing: add a batch dimension, reorder RGB -> BGR,
  -- rescale to 0..255, then subtract the mean pixel.
  local net_input = rgb:view(1, 3, height, width)
  net_input = net_input:index(2, torch.LongTensor{3, 2, 1}):mul(255)
  local mean_pixel = torch.FloatTensor{103.939, 116.779, 123.68}
  mean_pixel = mean_pixel:view(1, 3, 1, 1):expand(1, 3, height, width)
  net_input:add(-1, mean_pixel)

  -- Forward pass, then convert boxes from (xc, yc, w, h) to (x, y, w, h)
  local boxes, scores, captions = model:forward_test(net_input:type(dtype))

  return {
    img = rgb,
    boxes = box_utils.xcycwh_to_xywh(boxes),
    scores = scores,
    captions = captions,
  }
end
--- Flatten a run_image result into plain Lua tables/strings suitable for
-- JSON serialization (boxes and scores become nested number tables).
function result_to_json(result)
  return {
    boxes = result.boxes:float():totable(),
    scores = result.scores:float():view(-1):totable(),
    captions = result.captions,
  }
end
--- Draw detections directly onto the image using the Lua vis utilities,
-- bypassing the prettier JS/HTML visualization pipeline. Output is rough
-- but needs no browser.
function lua_render_result(result, opt)
  -- honor num_to_draw: keep at most that many boxes and matching captions
  local keep = math.min(opt.num_to_draw, result.boxes:size(1))
  local top_boxes = result.boxes[{{1, keep}}]

  local top_captions = {}
  for i = 1, keep do
    top_captions[i] = result.captions[i]
  end

  -- render boxes + captions onto a copy of the image and return it
  return vis_utils.densecap_draw(result.img, top_boxes, top_captions,
    { text_size = opt.text_size, box_width = opt.box_width })
end
--- Figure out which images to process and return their paths as a sequence.
-- Exactly one of -input_image, -input_dir, or -input_split must be set;
-- raises an error otherwise.
function get_input_images(opt)
  local image_paths = {}

  if opt.input_image ~= '' then
    -- single explicit image
    image_paths[1] = opt.input_image
  elseif opt.input_dir ~= '' then
    -- every non-hidden file in the given directory
    for fn in paths.files(opt.input_dir) do
      if string.sub(fn, 1, 1) ~= '.' then
        image_paths[#image_paths + 1] = paths.concat(opt.input_dir, fn)
      end
    end
  elseif opt.input_split ~= '' then
    -- look up the Visual Genome image ids belonging to the requested split
    local info = utils.read_json(opt.splits_json)
    local split_img_ids = info[opt.input_split] -- a sequence of integer ids
    for k = 1, #split_img_ids do
      image_paths[#image_paths + 1] =
        paths.concat(opt.vg_img_root_dir, tostring(split_img_ids[k]) .. '.jpg')
    end
  else
    error('one of input_image, input_dir, or input_split must be provided.')
  end

  return image_paths
end
--- Run the model over a list of image paths and serialize results to JSON.
-- Images that fail to load (corrupt or unreadable files) are skipped with a
-- warning instead of aborting the whole batch.
-- @param model       DenseCap model, already converted and in evaluate mode
-- @param image_paths sequence of image file paths to process
-- @param opt         parsed command-line options
-- @param idx         integer batch index; names the output results_<idx>.json
-- @param dtype       tensor type string for the model inputs
function process(model, image_paths, opt, idx, dtype)
  local num_process = #image_paths
  local results_json = {}

  local start = os.time()

  for k = 1, num_process do
    local img_path = image_paths[k]
    print(string.format('%d/%d processing image %s', k, num_process, img_path))

    -- pcall is a protected call: it returns a status flag plus the function
    -- output. This protects long-running jobs from corrupt or irregular
    -- input images.
    local status, result = pcall(run_image, model, img_path, opt, dtype)

    -- check for a good status
    if status == true then
      -- handle output serialization: either to directory or for pretty html vis
      if opt.output_dir ~= '' then
        -- render boxes/captions onto the image and save it
        local img_out = lua_render_result(result, opt)
        local img_out_path = paths.concat(opt.output_dir, paths.basename(img_path))
        image.save(img_out_path, img_out)
      end
      if opt.output_vis == 1 then
        if opt.save_processed_img == 1 then
          -- save the raw image to vis/data/
          local img_out_path = paths.concat(opt.output_vis_dir, paths.basename(img_path))
          image.save(img_out_path, result.img)
        end

        -- keep track of the (thin) json information with all result metadata
        local result_json = result_to_json(result)
        result_json.img_name = paths.basename(img_path)
        table.insert(results_json, result_json)
      end

      -- progress report: percent done and average seconds per image
      local total_elapsed = os.time() - start
      local percent_complete = k / num_process * 100
      local time_per_image = total_elapsed / k
      print(string.format("%0.2f %% complete in %i seconds (%0.2f per image)\n",
        percent_complete, total_elapsed, time_per_image))
    else
      print(string.format("Whoops! Couldn't use %s\n", img_path))
    end
  end

  if #results_json > 0 then
    -- serialize to json
    local out = {}
    out.results = results_json
    out.opt = opt
    utils.write_json(paths.concat(opt.output_vis_dir, string.format('results_%i.json', idx)), out)
  end
end
--[[
MAIN LOOP
--]]

-- Load the model checkpoint and cast the model to the right tensor type
local dtype, use_cudnn = utils.setup_gpus(opt.gpu, opt.use_cudnn)
local checkpoint = torch.load(opt.checkpoint)
local model = checkpoint.model
model:convert(dtype, use_cudnn)
model:setTestArgs{
  rpn_nms_thresh = opt.rpn_nms_thresh,
  final_nms_thresh = opt.final_nms_thresh,
  num_proposals = opt.num_proposals,
}
model:evaluate()

if opt.input_batch_dir ~= '' then
  -- Collect the batch text file paths first: paths.files is an iterator, so
  -- enumeration must be finished before we start processing.
  local batches = {}
  for file in paths.files(opt.input_batch_dir) do
    -- skip hidden files (names starting with a period)
    if string.sub(file, 1, 1) ~= '.' then
      -- make fully qualified path
      local text_file = paths.concat(opt.input_batch_dir, file)
      table.insert(batches, text_file)
    end
  end
  print(string.format("Found %i batch files", #batches))

  -- Each batch file lists one image path per line. Use ipairs (not pairs) so
  -- batches are processed in ascending order and results_<idx>.json numbering
  -- is deterministic.
  for idx, file in ipairs(batches) do
    local image_paths = {}
    for line in io.lines(file) do
      table.insert(image_paths, line)
    end
    process(model, image_paths, opt, idx, dtype)
  end
else
  local image_paths = get_input_images(opt)
  -- single (non-batch) run: output index is always 1
  process(model, image_paths, opt, 1, dtype)
end

0 commit comments

Comments
 (0)