Skip to content

我尝试使用Time-llm在UEA的SRS1 2分类进行尝试,但训练失败了 #190

@Sample-design-alt

Description

@Sample-design-alt

我尝试改写你的代码在分类问题上,但是结果只有50%。这是为什么呢?我也尝试把bf16改为了float32,还有哪些可能存在的问题导致我的模型无法训练?(还有其他几个数据集也存在同样的问题)

这是我的分类forward:
` def classification(self, x_enc):

        x_enc = self.normalize_layers(x_enc, 'norm')

        B, T, N = x_enc.size()
        x_enc = x_enc.permute(0, 2, 1).contiguous().reshape(B * N, T, 1)

        # 2. 计算统计特征 (保持不变,这些特征对分类也很有用)
        min_values = torch.min(x_enc, dim=1)[0]
        max_values = torch.max(x_enc, dim=1)[0]
        medians = torch.median(x_enc, dim=1).values
        lags = self.calcute_lags(x_enc)
        trends = x_enc.diff(dim=1).sum(dim=1)

        # 3. 修改 Prompt (重点修改部分)
        prompt = []
        for b in range(x_enc.shape[0]):
            min_values_str = str(min_values[b].tolist()[0])
            max_values_str = str(max_values[b].tolist()[0])
            median_values_str = str(medians[b].tolist()[0])
            lags_values_str = str(lags[b].tolist())
            
            # --- 修改点:提示词改为分类任务 ---
            prompt_ = (
                f"<|start_prompt|>Dataset description: {self.description}"
                # 这里的 class_names 需要你在初始化时定义,或者直接写死类别数量
                f"Task description: Classify the input time series into one of the {self.num_classes} classes; " 
                "Input statistics: "
                f"min value {min_values_str}, "
                f"max value {max_values_str}, "
                f"median value {median_values_str}, "
                f"the trend of input is {'upward' if trends[b] > 0 else 'downward'}, "
                f"top 5 lags are : {lags_values_str}<|<end_prompt>|>"
            )
            prompt.append(prompt_)

        x_enc = x_enc.reshape(B, N, T).permute(0, 2, 1).contiguous()

        # 4. LLM 编码过程 (基本保持不变)
        prompt = self.tokenizer(prompt, return_tensors="pt", padding=True, truncation=True, max_length=2048).input_ids
        prompt_embeddings = self.llm_model.get_input_embeddings()(prompt.to(x_enc.device))

        source_embeddings = self.mapping_layer(self.word_embeddings.permute(1, 0)).permute(1, 0)

        x_enc = x_enc.permute(0, 2, 1).contiguous()
        enc_out, n_vars = self.patch_embedding(x_enc)
        enc_out = self.reprogramming_layer(enc_out, source_embeddings, source_embeddings)
        llama_enc_out = torch.cat([prompt_embeddings, enc_out], dim=1)
        
        # 获取 LLM 输出
        dec_out = self.llm_model(inputs_embeds=llama_enc_out).last_hidden_state

        # ts_out = dec_out[:, -self.patch_nums:, :]
        # cls_token_out = ts_out.reshape(B, self.patch_nums, -1).mean(dim=1)
        # outputs = self.output_projection(cls_token_out)

        dec_out = dec_out[:,0,:]
        # cls_token_out = dec_out.mean(dim=1)
        cls_token_out = dec_out.reshape(B, N, -1).contiguous()
        # cls_token_out = cls_token_out.reshape(B, -1)

        outputs = self.output_projection(cls_token_out)
        `

Metadata

Metadata

Assignees

No one assigned

    Labels

    No labels
    No labels

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions