Skip to content
Closed
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
34 changes: 23 additions & 11 deletions .github/tests/lm_tests.py
Original file line number Diff line number Diff line change
Expand Up @@ -290,7 +290,7 @@ def test_filter_cascade(setup_models):
def test_join_cascade(setup_models):
models = setup_models
rm = SentenceTransformersRM(model="intfloat/e5-base-v2")
vs = FaissVS()
vs = FaissVS()
lotus.settings.configure(lm=models["gpt-4o-mini"], rm=rm, vs=vs)

data1 = {
Expand Down Expand Up @@ -503,7 +503,7 @@ def test_operator_cache(setup_models, model):
"Chemical Kinetics and Catalysis",
"Transport Phenomena and Separations",
],
"_map": [
"map_output": [
"Process Dynamics and Control",
"Advanced Optimization Techniques in Engineering",
"Reaction Kinetics and Mechanisms",
Expand All @@ -521,9 +521,13 @@ def test_operator_cache(setup_models, model):
second_response = df.sem_map(user_instruction)
assert lm.stats.total_usage.operator_cache_hits == 1

first_response["_map"] = first_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
second_response["_map"] = second_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
expected_response["_map"] = expected_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
first_response["map_output"] = first_response["map_output"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
second_response["map_output"] = (
second_response["map_output"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
)
expected_response["map_output"] = (
expected_response["map_output"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
)

pd.testing.assert_frame_equal(first_response, second_response)
pd.testing.assert_frame_equal(first_response, expected_response)
Expand Down Expand Up @@ -562,7 +566,7 @@ def test_disable_operator_cache(setup_models, model):
"Chemical Kinetics and Catalysis",
"Transport Phenomena and Separations",
],
"_map": [
"map_output": [
"Process Dynamics and Control",
"Advanced Optimization Techniques in Engineering",
"Reaction Kinetics and Mechanisms",
Expand All @@ -575,25 +579,33 @@ def test_disable_operator_cache(setup_models, model):
user_instruction = "What is a similar course to {Course Name}. Please just output the course name."

first_response = df.sem_map(user_instruction)
first_response["_map"] = first_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
first_response["map_output"] = first_response["map_output"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
assert lm.stats.total_usage.operator_cache_hits == 0

second_response = df.sem_map(user_instruction)
second_response["_map"] = second_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
second_response["map_output"] = (
second_response["map_output"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
)
assert lm.stats.total_usage.operator_cache_hits == 0

pd.testing.assert_frame_equal(first_response, second_response)

# Now enable operator cache.
lotus.settings.configure(enable_cache=True)
first_responses = df.sem_map(user_instruction)
first_responses["_map"] = first_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
first_responses["map_output"] = (
first_responses["map_output"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
)
assert lm.stats.total_usage.operator_cache_hits == 0
second_responses = df.sem_map(user_instruction)
second_responses["_map"] = second_responses["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
second_responses["map_output"] = (
second_responses["map_output"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
)
assert lm.stats.total_usage.operator_cache_hits == 1

expected_response["_map"] = expected_response["_map"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
expected_response["map_output"] = (
expected_response["map_output"].str.replace(r"[^a-zA-Z\s]", "", regex=True).str.lower()
)

pd.testing.assert_frame_equal(first_responses, second_responses)
pd.testing.assert_frame_equal(first_responses, expected_response)
Expand Down
2 changes: 1 addition & 1 deletion .pre-commit-config.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -13,7 +13,7 @@ repos:
args: ["--config-file", "mypy.ini"]
additional_dependencies:
- types-setuptools
- litellm>=1.51.0
- litellm>=1.61.0
- numpy>=1.25.0
- pandas>=2.0.0
- sentence-transformers>=3.0.1
Expand Down
2 changes: 1 addition & 1 deletion docs/requirements-docs.txt
Original file line number Diff line number Diff line change
Expand Up @@ -3,7 +3,7 @@ sphinx-rtd-theme==2.0.0

backoff==2.2.1
faiss-cpu==1.8.0.post1
litellm==1.51.0
litellm==1.61.0
numpy==1.26.4
pandas==2.2.2
sentence-transformers==3.0.1
Expand Down
32 changes: 32 additions & 0 deletions examples/op_examples/map_multi_col.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,32 @@
import logging

import pandas as pd
from pydantic import BaseModel

import lotus
from lotus.models import LM

lm = LM(model="gpt-4o-mini")

lotus.settings.configure(lm=lm)
data = {
"Course Name": [
"Probability and Random Processes",
"Optimization Methods in Engineering",
"Digital Design and Integrated Circuits",
"Computer Security",
]
}
df = pd.DataFrame(data)
user_instruction = (
"What is a similar course to {Course Name}. Also give a study plan for the similar course. Be concise."
)


class Course(BaseModel):
new_course_name: str
new_course_study_plan: str


df = df.sem_map(user_instruction, response_format=Course)
print(df)
10 changes: 8 additions & 2 deletions lotus/sem_ops/postprocessors.py
Original file line number Diff line number Diff line change
@@ -1,5 +1,7 @@
import json

from pydantic import BaseModel

import lotus
from lotus.types import (
SemanticExtractPostprocessOutput,
Expand Down Expand Up @@ -37,7 +39,9 @@ def map_postprocess_cot(llm_answers: list[str]) -> SemanticMapPostprocessOutput:
return SemanticMapPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations)


def map_postprocess(llm_answers: list[str], cot_reasoning: bool = False) -> SemanticMapPostprocessOutput:
def map_postprocess(
llm_answers: list[str], response_format: type[BaseModel], cot_reasoning: bool = False
) -> SemanticMapPostprocessOutput:
"""
Postprocess the output of the map operator.

Expand All @@ -51,7 +55,9 @@ def map_postprocess(llm_answers: list[str], cot_reasoning: bool = False) -> Sema
if cot_reasoning:
return map_postprocess_cot(llm_answers)

outputs: list[str] = llm_answers
outputs: list[dict[str, str]] = [
response_format.model_validate_json(llm_answer).model_dump() for llm_answer in llm_answers
]
explanations: list[str | None] = [None] * len(llm_answers)
return SemanticMapPostprocessOutput(raw_outputs=llm_answers, outputs=outputs, explanations=explanations)

Expand Down
27 changes: 22 additions & 5 deletions lotus/sem_ops/sem_map.py
Original file line number Diff line number Diff line change
@@ -1,6 +1,7 @@
from typing import Any, Callable

import pandas as pd
from pydantic import BaseModel

import lotus
from lotus.cache import operator_cache
Expand All @@ -15,7 +16,8 @@ def sem_map(
docs: list[dict[str, Any]],
model: lotus.models.LM,
user_instruction: str,
postprocessor: Callable[[list[str], bool], SemanticMapPostprocessOutput] = map_postprocess,
response_format: type[BaseModel],
postprocessor: Callable[[list[str], type[BaseModel], bool], SemanticMapPostprocessOutput] = map_postprocess,
examples_multimodal_data: list[dict[str, Any]] | None = None,
examples_answers: list[str] | None = None,
cot_reasoning: list[str] | None = None,
Expand Down Expand Up @@ -55,10 +57,10 @@ def sem_map(
show_safe_mode(estimated_cost, estimated_LM_calls)

# call model
lm_output: LMOutput = model(inputs, progress_bar_desc=progress_bar_desc)
lm_output: LMOutput = model(inputs, progress_bar_desc=progress_bar_desc, response_format=response_format)

# post process results
postprocess_output = postprocessor(lm_output.outputs, strategy in ["cot", "zs-cot"])
postprocess_output = postprocessor(lm_output.outputs, response_format, strategy in ["cot", "zs-cot"])
lotus.logger.debug(f"raw_outputs: {lm_output.outputs}")
lotus.logger.debug(f"outputs: {postprocess_output.outputs}")
lotus.logger.debug(f"explanations: {postprocess_output.explanations}")
Expand All @@ -72,6 +74,10 @@ def sem_map(
)


class DefaultResponseFormat(BaseModel):
map_output: str


@pd.api.extensions.register_dataframe_accessor("sem_map")
class SemMapDataframe:
"""DataFrame accessor for semantic map."""
Expand All @@ -89,14 +95,15 @@ def _validate(obj: pd.DataFrame) -> None:
def __call__(
self,
user_instruction: str,
postprocessor: Callable[[list[str], bool], SemanticMapPostprocessOutput] = map_postprocess,
postprocessor: Callable[[list[str], type[BaseModel], bool], SemanticMapPostprocessOutput] = map_postprocess,
return_explanations: bool = False,
return_raw_outputs: bool = False,
suffix: str = "_map",
examples: pd.DataFrame | None = None,
strategy: str | None = None,
safe_mode: bool = False,
progress_bar_desc: str = "Mapping",
response_format: type[BaseModel] | None = None,
) -> pd.DataFrame:
"""
Applies semantic map over a dataframe.
Expand Down Expand Up @@ -140,6 +147,13 @@ def __call__(
return_explanations = True
cot_reasoning = examples["Reasoning"].tolist()

# Create default response format if none provided
found_response_format: type[BaseModel]
if response_format is None:
found_response_format = DefaultResponseFormat
else:
found_response_format = response_format

output = sem_map(
multimodal_data,
lotus.settings.lm,
Expand All @@ -151,10 +165,13 @@ def __call__(
strategy=strategy,
safe_mode=safe_mode,
progress_bar_desc=progress_bar_desc,
response_format=found_response_format,
)

new_df = self._obj.copy()
new_df[suffix] = output.outputs
for col in found_response_format.model_fields.keys():
new_df[col] = [x[col] for x in output.outputs]

if return_explanations:
new_df["explanation" + suffix] = output.explanations
if return_raw_outputs:
Expand Down
4 changes: 2 additions & 2 deletions lotus/types.py
Original file line number Diff line number Diff line change
Expand Up @@ -49,14 +49,14 @@ class LogprobsForFilterCascade:
@dataclass
class SemanticMapPostprocessOutput:
raw_outputs: list[str]
outputs: list[str]
outputs: list[dict[str, str]]
Copy link
Copy Markdown
Collaborator Author

Choose a reason for hiding this comment

The reason will be displayed to describe this comment to others. Learn more.

The post-processed output will be a list[dict[str, str]] now since it will be like

[{'out_col_1': 'out_val_1', 'out_col_2': 'out_val_2'}, {'out_col_1': ...}`]

explanations: list[str | None]


@dataclass
class SemanticMapOutput:
raw_outputs: list[str]
outputs: list[str]
outputs: list[dict[str, str]]
explanations: list[str | None]


Expand Down
2 changes: 1 addition & 1 deletion pyproject.toml
Original file line number Diff line number Diff line change
Expand Up @@ -24,7 +24,7 @@ classifiers = [
dependencies = [
"backoff>=2.2.1,<3.0.0",
"faiss-cpu>=1.8.0.post1,<2.0.0",
"litellm>=1.51.0,<2.0.0",
"litellm>=1.61.0,<2.0.0",
"numpy>=1.25.0,<2.0.0",
"pandas>=2.0.0,<3.0.0",
"sentence-transformers>=3.0.1,<4.0.0",
Expand Down
2 changes: 1 addition & 1 deletion requirements.txt
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
backoff==2.2.1
faiss-cpu==1.8.0.post1
litellm==1.51.0
litellm==1.61.0
numpy==1.26.4
pandas==2.2.2
sentence-transformers==3.0.1
Expand Down