From deb4d3d1767f2daf9f8a624e1f14cf3c6b5a95e8 Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 6 Mar 2026 15:32:24 +0800 Subject: [PATCH 1/3] fix --- cookbook/client/twinkle/self_congnition.py | 3 +- .../model/transformers/transformers.py | 32 +++++++++++++++++++ src/twinkle/server/launcher.py | 2 +- src/twinkle/server/twinkle/model.py | 4 +-- 4 files changed, 37 insertions(+), 4 deletions(-) diff --git a/cookbook/client/twinkle/self_congnition.py b/cookbook/client/twinkle/self_congnition.py index dc318dcf..f9e56dd1 100644 --- a/cookbook/client/twinkle/self_congnition.py +++ b/cookbook/client/twinkle/self_congnition.py @@ -102,10 +102,11 @@ def train(): for step, batch in enumerate(dataloader): # Forward pass + backward pass (computes gradients) output = model.forward_backward(inputs=batch) + loss=output.get('loss', 'N/A') # Log the loss every 2 steps (aligned with gradient accumulation) if step % 2 == 0: - logger.info(f'Current is step {step // 2}, loss: {output}') + logger.info(f'Current is step {step // 2}, loss: {loss}') # Clip gradients to prevent exploding gradients (max norm = 1.0) model.clip_grad_norm(1.0) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index e6fbc397..af83d517 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -343,6 +343,31 @@ def _construct_default_optimizer_group(self): _device_mesh=self.device_mesh, ) + @staticmethod + def _to_cpu_safe_output(obj: Any) -> Any: + """Convert nested outputs into CPU-safe Python objects for HTTP transport.""" + import numpy as np + from collections.abc import Mapping + + if isinstance(obj, torch.Tensor): + tensor = torch_util.to_local_tensor(obj).detach().cpu() + if tensor.numel() == 1: + return tensor.item() + return tensor.tolist() + if isinstance(obj, np.ndarray): + if obj.size == 1: + return obj.item() + return obj.tolist() + if isinstance(obj, np.generic): + return obj.item() + if isinstance(obj, Mapping): + return {key: TransformersModel._to_cpu_safe_output(value) for key, value in obj.items()} + if isinstance(obj, tuple): + return [TransformersModel._to_cpu_safe_output(value) for value in obj] + if isinstance(obj, list): + return [TransformersModel._to_cpu_safe_output(value) for value in obj] + return obj + @remote_function() def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs): """Call forward function and record the inputs and outputs. @@ -515,6 +540,13 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr self.backward(**kwargs) return outputs + @remote_function(dispatch='slice_dp', collect='mean') + def forward_backward_http(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], + **kwargs): + """HTTP-safe forward/backward that materializes outputs before they leave the worker.""" + outputs = self.forward_backward(inputs=inputs, **kwargs) + return self._to_cpu_safe_output(outputs) + @remote_function() def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs): """ Clip the gradient norm diff --git a/src/twinkle/server/launcher.py b/src/twinkle/server/launcher.py index 334cd99d..843418c2 100644 --- a/src/twinkle/server/launcher.py +++ b/src/twinkle/server/launcher.py @@ -216,7 +216,7 @@ def _deploy_application(self, app_config: dict[str, Any]) -> None: # Pass http_options to server apps for internal proxy routing http_options = self.config.get('http_options', {}) - if http_options: + if import_path == 'server' and http_options: args['http_options'] = http_options # Build and deploy the application diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py index 858d0716..72384ba8 100644 --- a/src/twinkle/server/twinkle/model.py +++ b/src/twinkle/server/twinkle/model.py @@ -295,8 +295,8 @@ def forward_backward(self, request: Request, body: ForwardRequest): else: assert isinstance(inputs, dict) inputs = InputFeature(**inputs) if 'input_ids' in inputs else Trajectory(**inputs) - ret = self.model.forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) - return {'result': str(ret)} + ret = self.model.forward_backward_http(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) + return {'result': ret} @app.post('/get_train_configs') def get_train_configs(self, request: Request, body: AdapterRequest): From c0d747ed557b5131d8162dec2330e7701cdd422f Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 6 Mar 2026 15:42:40 +0800 Subject: [PATCH 2/3] fix --- src/twinkle/model/transformers/transformers.py | 4 +--- 1 file changed, 1 insertion(+), 3 deletions(-) diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index af83d517..062258de 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -362,9 +362,7 @@ def _to_cpu_safe_output(obj: Any) -> Any: return obj.item() if isinstance(obj, Mapping): return {key: TransformersModel._to_cpu_safe_output(value) for key, value in obj.items()} - if isinstance(obj, tuple): - return [TransformersModel._to_cpu_safe_output(value) for value in obj] - if isinstance(obj, list): + if isinstance(obj, (list, tuple)): return [TransformersModel._to_cpu_safe_output(value) for value in obj] return obj From 1c2407de33a1db6eedabcb4debbef99febb59ebf Mon Sep 17 00:00:00 2001 From: weikaiwen Date: Fri, 6 Mar 2026 16:40:13 +0800 Subject: [PATCH 3/3] fix --- .../model/transformers/transformers.py | 36 ++------------- .../twinkle/common/transformers_model.py | 46 +++++++++++++++++++ src/twinkle/server/twinkle/model.py | 6 +-- 3 files changed, 52 insertions(+), 36 deletions(-) create mode 100644 src/twinkle/server/twinkle/common/transformers_model.py diff --git a/src/twinkle/model/transformers/transformers.py b/src/twinkle/model/transformers/transformers.py index 062258de..0d92bd24 100644 --- a/src/twinkle/model/transformers/transformers.py +++ b/src/twinkle/model/transformers/transformers.py @@ -343,29 +343,6 @@ def _construct_default_optimizer_group(self): _device_mesh=self.device_mesh, ) - @staticmethod - def _to_cpu_safe_output(obj: Any) -> Any: - """Convert nested outputs into CPU-safe Python objects for HTTP transport.""" - import numpy as np - from collections.abc import Mapping - - if isinstance(obj, torch.Tensor): - tensor = torch_util.to_local_tensor(obj).detach().cpu() - if tensor.numel() == 1: - return tensor.item() - return tensor.tolist() - if isinstance(obj, np.ndarray): - if obj.size == 1: - return obj.item() - return obj.tolist() - if isinstance(obj, np.generic): - return obj.item() - if isinstance(obj, Mapping): - return {key: TransformersModel._to_cpu_safe_output(value) for key, value in obj.items()} - if isinstance(obj, (list, tuple)): - return [TransformersModel._to_cpu_safe_output(value) for value in obj] - return obj - @remote_function() def forward(self, *, inputs: Union[InputFeature, List[InputFeature], List[Trajectory]], **kwargs): """Call forward function and record the inputs and outputs. @@ -538,13 +515,6 @@ def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Tr self.backward(**kwargs) return outputs - @remote_function(dispatch='slice_dp', collect='mean') - def forward_backward_http(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], - **kwargs): - """HTTP-safe forward/backward that materializes outputs before they leave the worker.""" - outputs = self.forward_backward(inputs=inputs, **kwargs) - return self._to_cpu_safe_output(outputs) - @remote_function() def clip_grad_norm(self, max_grad_norm: float = 1.0, norm_type=2, **kwargs): """ Clip the gradient norm @@ -1138,7 +1108,7 @@ def get_train_configs(self, **kwargs) -> str: return expr # ========================================================================= - # Checkpoint Engine — Weight Sync (from CheckpointEngineMixin) + # Checkpoint Engine weight sync (from CheckpointEngineMixin) # ========================================================================= # prepare_checkpoint_engine, init_checkpoint_process_group, and # finalize_checkpoint_engine are inherited from CheckpointEngineMixin. @@ -1175,7 +1145,7 @@ def weight_generator(): if isinstance(model, PeftModel): model.unmerge_adapter() else: - # ── LoRA-only mode: send only adapter weights ──────────────── + # LoRA-only mode: send only adapter weights. # Use PEFT's get_peft_model_state_dict for clean LoRA extraction from peft.utils import get_peft_model_state_dict lora_state_dict = get_peft_model_state_dict(model, adapter_name=adapter_name) @@ -1186,7 +1156,7 @@ def weight_generator(): yield name, tensor else: - # ── Full model mode: send all weights (base model sync) ────── + # Full model mode: send all weights (base model sync). state_dict = model.state_dict() def weight_generator(): diff --git a/src/twinkle/server/twinkle/common/transformers_model.py b/src/twinkle/server/twinkle/common/transformers_model.py new file mode 100644 index 00000000..04de9d8c --- /dev/null +++ b/src/twinkle/server/twinkle/common/transformers_model.py @@ -0,0 +1,46 @@ +# Copyright (c) ModelScope Contributors. All rights reserved. +import numpy as np +import torch +from collections.abc import Mapping +from typing import Any, List, Union + +from twinkle import remote_class, remote_function +from twinkle.data_format import InputFeature, Trajectory +from twinkle.infra import _collect_func +from twinkle.model import MultiLoraTransformersModel + + +def collect_forward_backward_http_output(result, device_mesh=None): + aggregated = _collect_func('mean', result, device_mesh=device_mesh) + return TwinkleCompatTransformersModel._to_cpu_safe_output(aggregated) + + +@remote_class() +class TwinkleCompatTransformersModel(MultiLoraTransformersModel): + + @staticmethod + def _to_cpu_safe_output(obj: Any) -> Any: + """Convert nested outputs into CPU-safe Python objects for HTTP transport.""" + from twinkle.utils import torch_util + + if isinstance(obj, torch.Tensor): + tensor = torch_util.to_local_tensor(obj).detach().cpu() + if tensor.numel() == 1: + return tensor.item() + return tensor.tolist() + if isinstance(obj, np.ndarray): + if obj.size == 1: + return obj.item() + return obj.tolist() + if isinstance(obj, np.generic): + return obj.item() + if isinstance(obj, Mapping): + return {key: TwinkleCompatTransformersModel._to_cpu_safe_output(value) for key, value in obj.items()} + if isinstance(obj, (list, tuple)): + return [TwinkleCompatTransformersModel._to_cpu_safe_output(value) for value in obj] + return obj + + @remote_function(dispatch='slice_dp', collect=collect_forward_backward_http_output) + def forward_backward(self, *, inputs: Union[InputFeature, List[InputFeature], Trajectory, List[Trajectory]], + **kwargs): + return super().forward_backward(inputs=inputs, **kwargs) diff --git a/src/twinkle/server/twinkle/model.py b/src/twinkle/server/twinkle/model.py index 72384ba8..4bf4bf4b 100644 --- a/src/twinkle/server/twinkle/model.py +++ b/src/twinkle/server/twinkle/model.py @@ -182,8 +182,8 @@ def __init__(self, nproc_per_node: int, device_group: Dict[str, Any], device_mes instance_id=replica_id, **kwargs) else: - from twinkle.model import MultiLoraTransformersModel - self.model = MultiLoraTransformersModel( + from .common.transformers_model import TwinkleCompatTransformersModel + self.model = TwinkleCompatTransformersModel( model_id=model_id, device_mesh=self.device_mesh, remote_group=self.device_group.name, @@ -295,7 +295,7 @@ def forward_backward(self, request: Request, body: ForwardRequest): else: assert isinstance(inputs, dict) inputs = InputFeature(**inputs) if 'input_ids' in inputs else Trajectory(**inputs) - ret = self.model.forward_backward_http(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) + ret = self.model.forward_backward(inputs=inputs, adapter_name=adapter_name, **extra_kwargs) return {'result': ret} @app.post('/get_train_configs')