3 changes: 2 additions & 1 deletion cookbook/client/twinkle/self_congnition.py
@@ -102,10 +102,11 @@ def train():
     for step, batch in enumerate(dataloader):
         # Forward pass + backward pass (computes gradients)
         output = model.forward_backward(inputs=batch)
+        loss = output.get('loss', 'N/A')

         # Log the loss every 2 steps (aligned with gradient accumulation)
         if step % 2 == 0:
-            logger.info(f'Current is step {step // 2}, loss: {output}')
+            logger.info(f'Current is step {step // 2}, loss: {loss}')

         # Clip gradients to prevent exploding gradients (max norm = 1.0)
         model.clip_grad_norm(1.0)
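A note on the fix above: it assumes forward_backward returns a dict-like object in which 'loss' may be absent on some steps, so .get with a fallback avoids both a KeyError and logging the whole output object. A minimal sketch of the pattern (the result shape here is an assumption, not the library's documented contract):

    # Hypothetical shape of a forward_backward result.
    output = {'loss': 0.5, 'grad_norm': 1.3}

    # .get() degrades gracefully to 'N/A' when the key is absent.
    loss = output.get('loss', 'N/A')
    print(f'loss: {loss}')  # -> loss: 0.5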
6 changes: 3 additions & 3 deletions src/twinkle/model/transformers/transformers.py
@@ -1108,7 +1108,7 @@ def get_train_configs(self, **kwargs) -> str:
         return expr

     # =========================================================================
-    # Checkpoint Engine — Weight Sync (from CheckpointEngineMixin)
+    # Checkpoint Engine weight sync (from CheckpointEngineMixin)
     # =========================================================================
     # prepare_checkpoint_engine, init_checkpoint_process_group, and
     # finalize_checkpoint_engine are inherited from CheckpointEngineMixin.
@@ -1145,7 +1145,7 @@ def weight_generator():
                 if isinstance(model, PeftModel):
                     model.unmerge_adapter()
             else:
-                # ── LoRA-only mode: send only adapter weights ────────────────
+                # LoRA-only mode: send only adapter weights.
                 # Use PEFT's get_peft_model_state_dict for clean LoRA extraction
                 from peft.utils import get_peft_model_state_dict
                 lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name)
@@ -1156,7 +1156,7 @@ def weight_generator():
                     yield name, tensor

         else:
-            # ── Full model mode: send all weights (base model sync) ──────
+            # Full model mode: send all weights (base model sync).
             state_dict = model.state_dict()

             def weight_generator():
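Context for the LoRA-only branch: PEFT's get_peft_model_state_dict returns only the adapter tensors (the LoRA A/B matrices), keyed by parameter name, which is why the generator can stream just those during weight sync. A self-contained sketch, assuming an illustrative GPT-2 base model rather than anything from this repo:

    from peft import LoraConfig, get_peft_model
    from peft.utils import get_peft_model_state_dict
    from transformers import AutoModelForCausalLM

    # Illustrative base model and adapter config.
    base = AutoModelForCausalLM.from_pretrained('gpt2')
    model = get_peft_model(base, LoraConfig(r=8, target_modules=['c_attn']))

    # Only the adapter weights come back; frozen base weights are excluded.
    lora_state_dict = get_peft_model_state_dict(model, adapter_name='default')
    for name, tensor in lora_state_dict.items():
        print(name, tuple(tensor.shape))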
2 changes: 1 addition & 1 deletion src/twinkle/server/launcher.py
@@ -216,7 +216,7 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None:

         # Pass http_options to server apps for internal proxy routing
         http_options = self.config.get('http_options', {})
-        if http_options:
+        if import_path == 'server' and http_options:
             args['http_options'] = http_options

         # Build and deploy the application
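The guard above narrows http_options injection to the deployment whose import_path is 'server', so other apps no longer receive a keyword argument they do not accept. A toy sketch of the gating logic under assumed config shapes (both dicts here are hypothetical):

    # Hypothetical app configs; only the 'server' app accepts http_options.
    apps = [
        {'import_path': 'server', 'args': {}},
        {'import_path': 'model', 'args': {}},
    ]
    http_options = {'host': '127.0.0.1', 'port': 8000}

    for app in apps:
        if app['import_path'] == 'server' and http_options:
            app['args']['http_options'] = http_options

    # Only the server app picks up the proxy routing options.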
46 changes: 46 additions & 0 deletions src/twinkle/server/twinkle/common/transformers_model.py
@@ -0,0 +1,46 @@
+# Copyright (c) ModelScope Contributors. All rights reserved.
+import numpy as np
+import torch
+from collections.abc import Mapping
+from typing import Any, List, Union
+
+from twinkle import remote_class, remote_function
+from twinkle.data_format import InputFeature, Trajectory
+from twinkle.infra import _collect_func
+from twinkle.model import MultiLoraTransformersModel
+
+
+def collect_forward_backward_http_output(result, device_mesh=None):
+    aggregated = _collect_func('mean', result, device_mesh=device_mesh)
+    return TwinkleCompatTransformersModel._to_cpu_safe_output(aggregated)
+
+
+@remote_class()
+class TwinkleCompatTransformersModel(MultiLoraTransformersModel):
+
+    @staticmethod
+    def _to_cpu_safe_output(obj: Any) -> Any:
+        """Convert nested outputs into CPU-safe Python objects for HTTP transport."""
+        from twinkle.utils import torch_util
+
+        if isinstance(obj, torch.Tensor):
+            tensor = torch_util.to_local_tensor(obj).detach().cpu()
+            if tensor.numel() == 1:
+                return tensor.item()
+            return tensor.tolist()
+        if isinstance(obj, np.ndarray):
+            if obj.size == 1:
+                return obj.item()
+            return obj.tolist()
+        if isinstance(obj, np.generic):
+            return obj.item()
+        if isinstance(obj, Mapping):
+            return {key: TwinkleCompatTransformersModel._to_cpu_safe_output(value) for key, value in obj.items()}
+        if isinstance(obj, (list, tuple)):
+            return [TwinkleCompatTransformersModel._to_cpu_safe_output(value) for value in obj]
+        return obj
+
+    @remote_function(dispatch='slice_dp', collect=collect_forward_backward_http_output)
+    def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
+                         **kwargs):
+        return super().forward_backward(inputs=inputs, **kwargs)
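_to_cpu_safe_output recursively converts tensors and numpy values into plain Python scalars and lists so the collected forward_backward result survives JSON serialization over HTTP. A simplified stand-in illustrating the mapping (it skips the torch_util.to_local_tensor indirection and assumes plain CPU tensors):

    import torch

    def to_cpu_safe(obj):
        # Simplified stand-in for _to_cpu_safe_output.
        if isinstance(obj, torch.Tensor):
            t = obj.detach().cpu()
            return t.item() if t.numel() == 1 else t.tolist()
        if isinstance(obj, dict):
            return {k: to_cpu_safe(v) for k, v in obj.items()}
        if isinstance(obj, (list, tuple)):
            return [to_cpu_safe(v) for v in obj]
        return obj

    out = {'loss': torch.tensor(0.5), 'logits': torch.zeros(2, 2)}
    print(to_cpu_safe(out))
    # -> {'loss': 0.5, 'logits': [[0.0, 0.0], [0.0, 0.0]]}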
6 changes: 3 additions & 3 deletions src/twinkle/server/twinkle/model.py
@@ -182,8 +182,8 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mes
                 instance_id=replica_id,
                 **kwargs)
         else:
-            from twinkle.model import MultiLoraTransformersModel
-            self.model = MultiLoraTransformersModel(
+            from .common.transformers_model import TwinkleCompatTransformersModel
+            self.model = TwinkleCompatTransformersModel(
                 model_id=model_id,
                 device_mesh=self.device_mesh,
                 remote_group=self.device_group.name,
@@ -296,7 +296,7 @@ def forward_backward(self, request: Request, body: ForwardRequest):
         assert isinstance(inputs, dict)
         inputs = InputFeature(**inputs) if 'input_ids' in inputs else Trajectory(**inputs)
         ret = self.model.forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs)
-        return {'result': str(ret)}
+        return {'result': ret}

     @app.post('/get_train_configs')
     def get_train_configs(self, request: Request, body: AdapterRequest):
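Because the collect hook now emits CPU-safe primitives, the endpoint can return ret directly and clients receive structured JSON instead of a stringified dict. A hedged client-side sketch (host, port, and the request body are assumptions; the real schema is defined by ForwardRequest):

    import requests

    # Hypothetical request body and endpoint address.
    body = {'adapter_name': 'default', 'inputs': {'input_ids': [[1, 2, 3]]}}
    resp = requests.post('http://localhost:8000/forward_backward', json=body)

    result = resp.json()['result']
    # Previously a stringified dict; now a parseable structure, e.g. result['loss'].
    print(result)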