Skip to content
Open
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
19 changes: 15 additions & 4 deletions src/openworldlib/operators/hunyuan_worldplay_operator.py
Original file line number Diff line number Diff line change
Expand Up @@ -15,10 +15,21 @@


class HunyuanWorldPlayOperator(BaseOperator):
def __init__(self, operation_types=None, interaction_template=None):
def __init__(
self,
operation_types=None,
interaction_template=None,
*,
forward_speed: float = 0.08,
yaw_speed_deg: float = 3.0,
pitch_speed_deg: float = 3.0,
):
if operation_types is None:
operation_types = ["action_instruction"]
super().__init__(operation_types=operation_types)
self.forward_speed = forward_speed
self.yaw_speed_deg = yaw_speed_deg
self.pitch_speed_deg = pitch_speed_deg
self.interaction_template = interaction_template or [
"forward",
"backward",
Expand Down Expand Up @@ -117,9 +128,9 @@ def _is_action_sequence(self, interaction) -> bool:
return all(item in self.interaction_template for item in interaction)

def _actions_to_pose_json(self, actions: list[str]) -> dict:
forward_speed = 0.08
yaw_speed = np.deg2rad(3)
pitch_speed = np.deg2rad(3)
forward_speed = self.forward_speed
yaw_speed = np.deg2rad(self.yaw_speed_deg)
pitch_speed = np.deg2rad(self.pitch_speed_deg)
motions: list[dict] = []
for action in actions:
move = {}
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -38,6 +38,9 @@ def from_pretrained(
overlap_group_offloading: bool = True,
init_infer_state: bool = True,
infer_state_kwargs: Optional[dict] = None,
forward_speed: float = 0.08,
yaw_speed_deg: float = 3.0,
pitch_speed_deg: float = 3.0,
**kwargs
) -> 'HunyuanWorldPlayPipeline':
"""
Expand Down Expand Up @@ -90,7 +93,11 @@ def from_pretrained(
action_ckpt=model_path,
**kwargs
)
operators = HunyuanWorldPlayOperator()
operators = HunyuanWorldPlayOperator(
forward_speed=forward_speed,
yaw_speed_deg=yaw_speed_deg,
pitch_speed_deg=pitch_speed_deg,
)

return cls(
synthesis_model=synthesis_model,
Expand Down Expand Up @@ -165,6 +172,9 @@ def __call__(
model_type: str = "ar",
user_height: Optional[int] = None,
user_width: Optional[int] = None,
forward_speed: Optional[float] = None,
yaw_speed_deg: Optional[float] = None,
pitch_speed_deg: Optional[float] = None,
**kwargs
):
"""
Expand Down Expand Up @@ -194,6 +204,13 @@ def __call__(
Returns:
HunyuanVideoPipelineOutput: 包含生成的视频帧
"""
if forward_speed is not None:
self.operators.forward_speed = forward_speed
if yaw_speed_deg is not None:
self.operators.yaw_speed_deg = yaw_speed_deg
if pitch_speed_deg is not None:
self.operators.pitch_speed_deg = pitch_speed_deg

video_length = num_frames
pose_value = interactions if interactions is not None else pose
if pose_value is None:
Expand Down
3 changes: 3 additions & 0 deletions test/test_hunyuan_worldplay.py
Original file line number Diff line number Diff line change
Expand Up @@ -21,6 +21,9 @@
prompt=prompt,
image_path=image_path,
interactions=interaction_signal,
forward_speed=0.08,
yaw_speed_deg=3.0,
pitch_speed_deg=3.0,
)

save_video_path = os.path.join(output_path, "hunyuan_worldplay_demo.mp4")
Expand Down