Return LM Filter Probabilities#125

Open
sidjha1 wants to merge 8 commits into main from sid/filter-probs

Conversation

@sidjha1
Collaborator

@sidjha1 sidjha1 commented Feb 20, 2025

Introduces a return_probs option for sem_filter. The output can look like:

                   Course Name Department  Level  filter_label  probs_filter
0  Introduction to Programming         CS    100         False      0.001927
1         Advanced Programming         CS    200         False      0.000024
2               Cooking Basics   Culinary    100          True      0.880797
3       Advanced Culinary Arts   Culinary    200          True      0.851953
4              Data Structures         CS    300         False      0.000261
5                   Algorithms         CS    300         False      0.000553
6               French Cuisine   Culinary    200          True      0.985936
7              Italian Cooking   Culinary    200          True      0.997527

Or

              Course Name Department  Level  probs_filter
2          Cooking Basics   Culinary    100      0.851953
3  Advanced Culinary Arts   Culinary    200      0.851953
6          French Cuisine   Culinary    200      0.985936
7         Italian Cooking   Culinary    200      0.999665

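Which form is returned depends on whether return_all is set: the first table shows the output with return_all=True, the second without it.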
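As a rough illustration of the two output shapes, here is a minimal plain-Python sketch. It is not the lotus implementation: the helper `shape_output` and the 0.5 threshold are assumptions (the threshold mirrors the assertions in this PR's tests), and the probabilities are copied from the first table.

```python
# Hypothetical sketch, not lotus code. Probabilities are copied from the first table.
rows = [
    {"Course Name": "Introduction to Programming", "probs_filter": 0.001927},
    {"Course Name": "Advanced Programming", "probs_filter": 0.000024},
    {"Course Name": "Cooking Basics", "probs_filter": 0.880797},
    {"Course Name": "Advanced Culinary Arts", "probs_filter": 0.851953},
    {"Course Name": "Data Structures", "probs_filter": 0.000261},
    {"Course Name": "Algorithms", "probs_filter": 0.000553},
    {"Course Name": "French Cuisine", "probs_filter": 0.985936},
    {"Course Name": "Italian Cooking", "probs_filter": 0.997527},
]


def shape_output(rows, return_all=False, threshold=0.5):
    """Mimic the two shapes: label every row, or keep only the passing rows."""
    # Assumed rule: a row passes when its probability exceeds the threshold.
    labeled = [{**r, "filter_label": r["probs_filter"] > threshold} for r in rows]
    if return_all:
        # All rows are kept, each with its boolean label.
        return labeled
    # Otherwise only passing rows survive and the label column is dropped.
    return [
        {k: v for k, v in r.items() if k != "filter_label"}
        for r in labeled
        if r["filter_label"]
    ]
```

Calling `shape_output(rows, return_all=True)` keeps all eight rows with a `filter_label` entry, while `shape_output(rows)` keeps only the four culinary courses.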

This PR also fixes a bug in the way that logprobs were analyzed for the True/False prob calculations.

@sidjha1 sidjha1 requested a review from liana313 February 20, 2025 04:47
Comment thread tests/test_filter.py Outdated
Comment on lines -9 to -28
return pd.DataFrame({
    "Course Name": [
        "Introduction to Programming",
        "Advanced Programming",
        "Cooking Basics",
        "Advanced Culinary Arts",
        "Data Structures",
        "Algorithms",
        "French Cuisine",
        "Italian Cooking"
    ],
    "Department": [
        "CS", "CS", "Culinary", "Culinary",
        "CS", "CS", "Culinary", "Culinary"
    ],
    "Level": [
        100, 200, 100, 200,
        300, 300, 200, 200
    ]
})
Collaborator Author

This file contains a bunch of ruff formatting changes. It's that weird thing where the CI does not work the same way as local. In any case this formatting is better.

Comment thread tests/test_filter.py Outdated
Comment on lines +106 to +127
class TestFilterWithProbs(BaseTest):
    def test_filter_with_probs(self, sample_df):
        """Test semantic filter with probabilities returned to the user"""
        lm = LM(model="gpt-4o-mini")
        lotus.settings.configure(lm=lm)
        result = sample_df.sem_filter("{Course Name} will be fun", return_probs=True)
        print(result)
        assert "probs_filter" in result.columns

    def test_filter_with_probs_and_return_all(self, sample_df):
        """Test semantic filter with probabilities returned to the user"""
        lm = LM(model="gpt-4o-mini")
        lotus.settings.configure(lm=lm)
        result = sample_df.sem_filter("{Course Name} will be fun", return_probs=True, return_all=True)
        print(result)
        assert "probs_filter" in result.columns

        for idx, row in result.iterrows():
            if row["filter_label"]:
                assert row["probs_filter"] > 0.5
            else:
                assert row["probs_filter"] <= 0.5
Collaborator Author

Core tests are here

vs = FaissVS()

- lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini, rm=rm)
+ lotus.settings.configure(lm=gpt_4o, helper_lm=gpt_4o_mini, rm=rm, vs=vs)
Collaborator Author

Not related to my logic changes, but vs needs to be set here for the example to run, given the recent merge.

Comment thread lotus/models/lm.py Outdated
Comment on lines +196 to +197
cleaned_token = logprob.token.lower().strip()
if cleaned_token not in ["true", "false"]:
Collaborator Author

Needed to add this .lower().strip() cleaning in a few places so we can process token variants that differ only in case or surrounding whitespace (True, true, a padded True, etc.).
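A minimal sketch of how this kind of cleaning can feed a True/False probability. This is not the code in lotus/models/lm.py: the function name `true_probability`, the `(token, logprob)` input shape, and the renormalization over the two answers are all assumptions made for illustration.

```python
import math


def true_probability(top_logprobs):
    """Estimate P(True) from (token, logprob) pairs at the answer position.

    Hypothetical helper: the input shape and the renormalization step are
    assumptions, not the library's actual logic.
    """
    mass = {"true": 0.0, "false": 0.0}
    for token, logprob in top_logprobs:
        # Same cleaning as in the diff: ignore case and surrounding whitespace.
        cleaned = token.lower().strip()
        if cleaned not in mass:
            continue  # skip tokens that are neither answer
        mass[cleaned] += math.exp(logprob)
    total = mass["true"] + mass["false"]
    if total == 0.0:
        return None  # no usable True/False token was found
    # Renormalize so the two answers sum to 1.
    return mass["true"] / total
```

Without the cleaning step, tokens such as a padded or lowercase "true" would be skipped and their probability mass lost.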

Collaborator

@liana313 liana313 left a comment

Looks good for the most part; I left a couple of comments about naming. Once those are addressed, we can also update the filter docs section.

Comment thread lotus/sem_ops/sem_filter.py Outdated
new_df["raw_output" + suffix] = filtered_raw_outputs

if return_scores:
    new_df["score"] = filtered_scores
Collaborator

We should add an index to the end of the output column labels, since a naming conflict occurs if a user filters twice. We should also add a test for filtering twice.
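One possible way to pick such an index, sketched in plain Python. The helper name `next_free_suffix` and the `_1`, `_2` naming scheme are hypothetical; the diff only shows that a `suffix` variable already exists for the raw-output column.

```python
def next_free_suffix(existing_columns, base="filter_label"):
    """Return a suffix that makes `base` unique among `existing_columns`.

    Hypothetical helper for illustration; not part of lotus.
    """
    existing = set(existing_columns)
    if base not in existing:
        return ""  # first filter: no suffix needed
    index = 1
    # Walk upward until base_<index> is unused.
    while f"{base}_{index}" in existing:
        index += 1
    return f"_{index}"
```

A second filter on a frame that already has `filter_label` would then write to `filter_label_1` instead of overwriting the first result.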

Comment thread lotus/sem_ops/sem_filter.py Outdated
@sidjha1 sidjha1 requested a review from liana313 April 1, 2025 16:13
2 participants