Question about input_token_logprobs #2

@WissamAntoun

Description

Hey, thanks a lot for the great work and for releasing the codebase.

I've been trying to create my own annotations for my own dataset using prompt_classify.py, but I've been getting errors here:

https://github.com/CodeCreator/WebOrganizer/blob/main/define_domains/prompt_classify.py#L132

assert all(
                len(answer_tokens) > 1
                for answer_tokens in meta_info["input_token_logprobs"]
            ), f"All answers should have at least 2 tokens in {meta_info['input_token_logprobs']}"

The thing is, the SGLang classifier has been returning the answer/choice token as the first token, like the following:

input_token_logprobs = [
[[-4.267991065979004, 32, 'A']],
[[-4.017991065979004, 33, 'B']],
[[-4.517991065979004, 34, 'C']],
[[-4.392991065979004, 35, 'D']],
[[-3.142991065979004, 36, 'E']],
[[-4.517991065979004, 37, 'F']],
[[-3.017991065979004, 38, 'G']],
[[-3.142991065979004, 39, 'H']],
[[-2.392991065979004, 40, 'I']],
[[-3.5137503147125244, 41, 'J']],
[[-3.2658731937408447, 42, 'K']],
[[-3.2658731937408447, 43, 'L']],
[[-2.3887503147125244, 44, 'M']],
[[-3.8887503147125244, 45, 'N']],
[[-4.013750076293945, 46, 'O']],
[[-3.6387503147125244, 47, 'P']],
[[-2.7637503147125244, 48, 'Q']],
[[-2.6387503147125244, 49, 'R']],
[[-2.5137503147125244, 50, 'S']],
[[-2.3887503147125244, 51, 'T']],
[[-4.138750076293945, 52, 'U']],
[[-3.8887503147125244, 53, 'V']],
[[-2.8887503147125244, 54, 'W']],
[[-3.2637503147125244, 55, 'X']]
]
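
Just to make the failure concrete, here is a minimal sketch (not from the repo, reusing the first two entries from the output above): every answer entry contains exactly one token, so `len(answer_tokens) > 1` is never true and the assertion fires.

input_token_logprobs = [
    [[-4.267991065979004, 32, 'A']],
    [[-4.017991065979004, 33, 'B']],
    # ... remaining single-token entries as above
]

# The check from prompt_classify.py evaluates to False and raises.
print(all(len(answer_tokens) > 1 for answer_tokens in input_token_logprobs))  # False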

I did some digging in SGLang's code, and I found that in their runtime they actually remove the first non-choice token and recompute the logprobs. See the code here: https://github.com/sgl-project/sglang/blob/main/python/sglang/lang/backend/runtime_endpoint.py#L267

        # Remove extra token if no token healing occurred
        for i in range(len(input_token_logprobs)):
            healed_token_str = input_token_logprobs[i][0][-1]
            if s.text_.endswith(healed_token_str):
                healed_token_logprob = input_token_logprobs[i][0][0]
                normalized_prompt_logprobs[i] = (
                    normalized_prompt_logprobs[i] * len(input_token_logprobs[i])
                    - healed_token_logprob
                ) / (len(input_token_logprobs[i]) - 1)
                input_token_logprobs[i] = input_token_logprobs[i][1:]
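
If I read the arithmetic right, `normalized_prompt_logprobs[i]` is the mean logprob over the answer tokens, so dropping the healed token just removes its contribution from the average. A small worked example with made-up numbers:

# Hypothetical logprobs, only to illustrate the recomputation above.
token_logprobs = [-4.27, -0.90, -1.10]   # first entry plays the role of the "healed" token
n = len(token_logprobs)
old_mean = sum(token_logprobs) / n       # mean over all tokens

healed_token_logprob = token_logprobs[0]
new_mean = (old_mean * n - healed_token_logprob) / (n - 1)

# Same result as averaging over the remaining tokens directly.
assert abs(new_mean - sum(token_logprobs[1:]) / (n - 1)) < 1e-9
print(old_mean, new_mean)                # roughly -2.09 -> -1.0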

Could it be that they have already addressed the issue you are accounting for?
