Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
16 changes: 8 additions & 8 deletions src/integrationtest/python/system_tests.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
import os
import tempfile
import sh
import subprocess
import json
import atexit
import shutil
Expand Down Expand Up @@ -59,14 +59,14 @@ def check_generated_files():
source_frames.write(b'\x00' * 30000)
source_frames.write(b'\xff' * 30000)

sh.ffmpeg('-t', '10',
subprocess.call(['ffmpeg','-t', '10',
'-s', '100x100',
'-f', 'rawvideo',
'-pix_fmt', 'rgb24',
'-r', '8',
'-i', './source',
'source.mp4',
_cwd=temp_dir,
'source.mp4']
cwd=temp_dir,
)

json_content = []
Expand All @@ -86,11 +86,11 @@ def check_generated_files():
# step 2: run the gulping

# PATH=src/main/scripts:$PATH PYTHONPATH=src/main/python
command = sh.gulp_20bn_json_videos(
command = subprocess.call(['gulp_20bn_json_videos',
'--videos_per_chunk', '10',
os.path.join(temp_dir, 'videos.json'),
temp_dir,
output_dir,
output_dir]
)

# step 3: sanity check the output
Expand Down Expand Up @@ -123,11 +123,11 @@ def check_generated_files():

# step 6: extend the existing gulps

command = sh.gulp_20bn_json_videos(
command = subprocess.call(['gulp_20bn_json_videos',
'--videos_per_chunk', '10',
os.path.join(temp_dir, 'videos_extend.json'),
temp_dir,
output_dir,
output_dir]
)

# step 7: sanity check the extended output
Expand Down
12 changes: 7 additions & 5 deletions src/main/python/gulpio/dataset.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,16 +17,16 @@ def __init__(self, data_path, num_frames, step_size,

Args:
data_path (str): path to GulpIO dataset folder
label_path (str): path to GulpIO label dictionary matching
label ids to label names
num_frames (int): number of frames to be fetched.
step_size (int): number of frames skippid while picking
sequence of frames from each video.
is_va (bool): sets the necessary augmention procedure.
is_val (bool): sets the necessary augmention procedure.
transform (object): set of augmentation steps defined by
Compose(). Default is None.
target_transform (func): performs preprocessing on labels if
defined. Default is None.
target_transform (func): a transformation function applied to each
single target, where target is the id assigned to a label. The
mapping from label to id is provided in the `label_idx` member-
variable. Default is None.
stack (bool): stack frames into a numpy.array. Default is True.
random_offset (bool): random offsetting to pick frames, if
number of frames are more than what is necessary.
Expand Down Expand Up @@ -88,6 +88,8 @@ def __getitem__(self, index):
# augmentation
if self.transform_video:
frames = self.transform_video(frames)
if self.target_transform:
target_idx = self.target_transform(target_idx)
# format data to torch tensor
if self.stack:
frames = np.stack(frames)
Expand Down
7 changes: 7 additions & 0 deletions src/unittest/python/loader_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -84,6 +84,13 @@ def test_dataset(self):
False, stack=False)
self.iterate(loader)

def test_target_transform(self):
self.create_chunk()
target_label = 7
dataset = GulpVideoDataset(self.temp_dir, 2, 2,
False, target_transform=lambda y: target_label)
assert dataset[0][1] == target_label


class TestGulpImageDataset(unittest.TestCase):

Expand Down