diff --git a/cookbook/client/twinkle/self_congnition.py b/cookbook/client/twinkle/self_congnition.py index dc318dcf..f9e56dd1 100644 --- a/cookbook/client/twinkle/self_congnition.py +++ b/cookbook/client/twinkle/self_congnition.py @@ -102,10 +102,11 @@ def train(): for step, batch in enumerate(dataloader): # Forward pass + backward pass (computes gradients) output = model.forward_backward(inputs=batch) + loss=output.get('loss', 'N/A') # Log the loss every 2 steps (aligned with gradient accumulation) if step % 2 == 0: - logger.info(f'Current is step {step // 2}, loss: {output}') + logger.info(f'Current is step {step // 2}, loss: {loss}') # Clip gradients to prevent exploding gradients (max norm = 1.0) model.clip_grad_norm(1.0) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index e6fbc397..0d92bd24 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -1108,7 +1108,7 @@ def get_train_configs(self, **kwargs) -> str: return expr # ========================================================================= - # Checkpoint Engine — Weight Sync (from CheckpointEngineMixin) + # Checkpoint Engine weight sync (from CheckpointEngineMixin) # ========================================================================= # prepare_checkpoint_engine, init_checkpoint_process_group, and # finalize_checkpoint_engine are inherited from CheckpointEngineMixin. @@ -1145,7 +1145,7 @@ def weight_generator(): if isinstance(model, PeftModel): model.unmerge_adapter() else: - # ── LoRA-only mode: send only adapter weights ──────────────── + # LoRA-only mode: send only adapter weights. 
# Use PEFT's get_peft_model_state_dict for clean LoRA extraction from peft.utils import get_peft_model_state_dict lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name) @@ -1156,7 +1156,7 @@ def weight_generator(): yield name, tensor else: - # ── Full model mode: send all weights (base model sync) ────── + # Full model mode: send all weights (base model sync). state_dict = model.state_dict() def weight_generator(): diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index 334cd99d..843418c2 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -216,7 +216,7 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None: # Pass http_options to server apps for internal proxy routing http_options = self.config.get('http_options', {}) - if http_options: + if import_path == 'server' and http_options: args['http_options'] = http_options # Build and deploy the application diff --git a/src/twinkle/server/twinkle/common/transformers_model.py b/src/twinkle/server/twinkle/common/transformers_model.py new file mode 100644 index 00000000..04de9d8c --- /dev/null +++ b/src/twinkle/server/twinkle/common/transformers_model.py @@ -0,0 +1,46 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. 
+import numpy as np
+import torch
+from collections.abc import Mapping
+from typing import Any, List, Union
+
+from twinkle import remote_class, remote_function
+from twinkle.data_format import InputFeature, Trajectory
+from twinkle.infra import _collect_func
+from twinkle.model import MultiLoraTransformersModel
+
+
+def collect_forward_backward_http_output(result, device_mesh=None):
+    """Aggregate ``result`` via ``_collect_func('mean', ...)`` and convert it for HTTP transport."""
+    merged = _collect_func('mean', result, device_mesh=device_mesh)
+    return TwinkleCompatTransformersModel._to_cpu_safe_output(merged)
+
+
+@remote_class()
+class TwinkleCompatTransformersModel(MultiLoraTransformersModel):
+    """``MultiLoraTransformersModel`` whose ``forward_backward`` result is post-processed for HTTP transport."""
+
+    @staticmethod
+    def _to_cpu_safe_output(obj: Any) -> Any:
+        """Recursively turn tensors, arrays and containers into Python scalars, lists and dicts."""
+        from twinkle.utils import torch_util
+
+        convert = TwinkleCompatTransformersModel._to_cpu_safe_output
+        if isinstance(obj, torch.Tensor):
+            local = torch_util.to_local_tensor(obj).detach().cpu()
+            return local.item() if local.numel() == 1 else local.tolist()  # one element -> scalar
+        if isinstance(obj, np.ndarray):
+            return obj.item() if obj.size == 1 else obj.tolist()  # one element -> scalar
+        if isinstance(obj, np.generic):
+            return obj.item()
+        if isinstance(obj, Mapping):
+            return {key: convert(value) for key, value in obj.items()}
+        if isinstance(obj, (list, tuple)):
+            return [convert(value) for value in obj]
+        return obj  # anything else is passed through unchanged
+
+    @remote_function(dispatch='slice_dp', collect=collect_forward_backward_http_output)
+    def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]],
+                         **kwargs):
+        """Delegate to the parent ``forward_backward``; the ``collect`` hook post-processes the result."""
+        return super().forward_backward(inputs=inputs, **kwargs)
diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py index 858d0716..4bf4bf4b 100644 --- a/src/twinkle/server/twinkle/model.py +++ b/src/twinkle/server/twinkle/model.py @@ -182,8 +182,8 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mes instance_id=replica_id, **kwargs) 
else: - from twinkle.model import MultiLoraTransformersModel - self.model = MultiLoraTransformersModel( + from .common.transformers_model import TwinkleCompatTransformersModel + self.model = TwinkleCompatTransformersModel( model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, @@ -296,7 +296,7 @@ def forward_backward(self, request: Request, body: ForwardRequest): assert isinstance(inputs, dict) inputs = InputFeature(**inputs) if 'input_ids' in inputs else Trajectory(**inputs) ret = self.model.forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) - return {'result': str(ret)} + return {'result': ret} @app.post('/get_train_configs') def get_train_configs(self, request: Request, body: AdapterRequest):