4 changes: 3 additions & 1 deletion backend/console/heuristic_filter.py
@@ -45,7 +45,8 @@ class FilterPropertyName:
    property_methanol_perm = 'methanol permeability'
    property_n2_perm = 'N_{2} permeability'
    property_ch4_perm = 'CH_{4} permeability'

    property_water_perm = 'water permeability'
    # water_perm_hf_sel1k = 'water permeability'

def add_args(subparsers: _SubParsersAction):
    parser: ArgumentParser = subparsers.add_parser(
@@ -114,6 +115,7 @@ def run(args: ArgumentParser):
    '''

    # #Query for sel1k
    # log.note("Filter for select-1k")
    # query = '''
    # SELECT pt.id AS para_id FROM paper_texts pt
    # JOIN filtered_papers fp ON fp.doi = pt.doi
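The new 'water permeability' entry extends the heuristic filter's property vocabulary. A minimal sketch of how such a property name is typically matched against paragraph text; the helper and regex below are illustrative, not part of the PR:

import re

property_water_perm = 'water permeability'

def mentions_property(text: str, prop: str = property_water_perm) -> bool:
    # Case-insensitive literal match of the property name in a paragraph.
    return re.search(re.escape(prop), text, flags=re.IGNORECASE) is not None

assert mentions_property("The water permeability of the membrane was 2.1 L/(m2 h bar).")
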
7 changes: 4 additions & 3 deletions backend/console/llm_pipeline.py
@@ -60,9 +60,10 @@ def run(args: ArgumentParser):
        polyai.api.api_key = sett.LLMPipeline.polyai_key

    else:
        log.critical("Unrecognized api '{}' defined in the method {}",
                     method.api, args.method)
        exit(1)
        log.warn("No api being used.")
        # log.critical("Unrecognized api '{}' defined in the method {}",
        #              method.api, args.method)
        # exit(1)


    para_filter_name = method.para_subset
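This change softens the API dispatch: an unrecognized api no longer aborts the pipeline, so a locally hosted model can handle the request downstream. A sketch of the revised control flow; the openai branch is assumed for illustration, only the polyai branch is visible in the hunk:

if method.api == 'openai':
    openai.api_key = sett.LLMPipeline.openai_key
elif method.api == 'polyai':
    polyai.api.api_key = sett.LLMPipeline.polyai_key
else:
    # Previously: log.critical(...) followed by exit(1).
    log.warn("No api being used.")
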
2 changes: 2 additions & 0 deletions backend/console/ps_ner_filter.py
@@ -46,6 +46,8 @@ class HeuristicFilterName:
    n2_perm_ner_full = "property_n2_perm"
    ch4_perm_ner_full = "property_ch4_perm"
    methanol_perm_ner_full = "property_methanol_perm"
    water_perm_ner_full = 'property_water_perm'
    # water_perm_ner_sel1k = 'water_perm_hf_sel1k'


def add_args(subparsers: _SubParsersAction):
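The NER filter name added here points back to the heuristic filter defined above, keeping the two stages linked by name. A hypothetical lookup, assuming the attributes behave as plain string constants:

heuristic_name = HeuristicFilterName.water_perm_ner_full
assert heuristic_name == 'property_water_perm'
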
101 changes: 101 additions & 0 deletions backend/prompt_extraction/exllama_model.py
@@ -0,0 +1,101 @@
from exllamav2 import ExLlamaV2, ExLlamaV2Config, ExLlamaV2Cache, ExLlamaV2Tokenizer
from exllamav2.generator import ExLlamaV2StreamingGenerator, ExLlamaV2Sampler
from transformers import AutoTokenizer
import time
import pylogg as log


class ExLlamaV2Model:
    def __init__(self, model_directory, model_name, max_new_tokens=150, max_seq_len=8000):
        self.model_directory = model_directory
        # self.template_path = template_path
        self.model_name = model_name  # option between [llama3, phi3]
        self.max_new_tokens = max_new_tokens
        self.max_seq_len = max_seq_len
        self.model, self.cache, self.tokenizer, self.chat_tokenizer, self.config = self.initialize_model()
        self.generator, self.settings = self.setup_generator()

    def initialize_model(self):
        """Initializes the model, tokenizer, cache, and other settings."""
        print(f"Loading model: {self.model_directory}")
        # template = "".join([line.strip() for line in open(self.template_path)])
        chat_tokenizer = AutoTokenizer.from_pretrained(self.model_directory, use_fast=False)

        # Sanity check
        if not chat_tokenizer.chat_template:
            raise ValueError("Chat template not specified in 'tokenizer_config.json'")

        # chat_tokenizer.chat_template = template

        config = ExLlamaV2Config(self.model_directory)
        model = ExLlamaV2(config)
        cache = ExLlamaV2Cache(model, max_seq_len=self.max_seq_len, lazy=True)
        model.load_autosplit(cache, progress=True)
        tokenizer = ExLlamaV2Tokenizer(config)

        # 128009 is Llama-3's <|eot_id|> end-of-turn token.
        tokenizer.eos_token = ''
        tokenizer.eos_token_id = 128009

        return model, cache, tokenizer, chat_tokenizer, config

    def setup_generator(self):
        """Sets up the generator with the given model, cache, and tokenizer."""
        settings = ExLlamaV2Sampler.Settings()
        settings.temperature = 0.85
        settings.top_k = 50
        settings.top_p = 0.8
        settings.token_repetition_penalty = 1.01

        generator = ExLlamaV2StreamingGenerator(self.model, self.cache, self.tokenizer)
        generator.warmup()

        generator.set_stop_conditions([self.chat_tokenizer.eos_token_id])

        if self.model_name == 'llama3':
            stop_conditions = [self.tokenizer.eos_token_id]
            generator.set_stop_conditions(stop_conditions)
        # elif self.model_name == 'phi3':
        #     generator.set_stop_conditions([self.chat_tokenizer.eos_token_id])

        return generator, settings

    def generate_text(self, chat):
        """Generates text based on the chat history."""
        prompt = self.chat_tokenizer.apply_chat_template(chat, tokenize=False, add_generation_prompt=True)
        log.info(f"Prompt: {prompt}")
        # input_ids = self.tokenizer.encode(prompt, add_bos=True)

        input_ids = self.chat_tokenizer.apply_chat_template(chat, tokenize=True, add_generation_prompt=True, return_tensors="pt")

        self.generator.begin_stream_ex(input_ids, self.settings, loras=None, decode_special_tokens=True)

        # if self.model_name == 'phi3':
        generated_tokens = []
        output = ''

        while True:
            res = self.generator.stream_ex()
            chunk = res["chunk"]
            output += chunk
            eos = res["eos"]

            generated_tokens += res['chunk_token_ids'][0].tolist()
            print(chunk, end="", flush=True)

            # A chunk may carry several tokens, so compare with >= to
            # guarantee the loop terminates at the token budget.
            if eos or len(generated_tokens) >= self.max_new_tokens:
                print("\n")
                break

        return output

        # generated_tokens = 0
        # output = ''
        # while True:
        #     res = self.generator.stream_ex()
        #     chunk = res["chunk"]
        #     output += chunk
        #     eos = res["eos"]
        #     generated_tokens += 1
        #     if eos or generated_tokens == self.max_new_tokens:
        #         break

        # return output
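
A hypothetical usage of the new wrapper; the model directory and messages are placeholders. generate_text streams chunks to stdout and returns the full completion string:

model = ExLlamaV2Model(
    model_directory="/path/to/Llama-3-8B-Instruct-exl2",
    model_name="llama3",
    max_new_tokens=150,
    max_seq_len=8000,
)
chat = [
    {"role": "system", "content": "You extract material property records as JSON."},
    {"role": "user", "content": "Extract all numbers with 'material', 'property', 'value' columns."},
]
completion = model.generate_text(chat)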
58 changes: 53 additions & 5 deletions backend/prompt_extraction/prompt_extractor.py
@@ -14,9 +14,17 @@
from backend.text.normalize import TextNormalizer
from backend.prompt_extraction.shot_selection import ShotSelector
from backend.postgres.orm import APIRequests, PaperTexts, ExtractionMethods
from backend.prompt_extraction.exllama_model import ExLlamaV2Model

log = pylogg.New('llm')

model_name = 'llama3'
max_seq_len = 4000
max_new_tokens = 512
model_directory = "/data/sonakshi/SynthesisRecipes/models/Llama-3-8B-Instruct-exl2"
model = ExLlamaV2Model(model_directory, max_new_tokens=max_new_tokens, max_seq_len=max_seq_len, model_name=model_name)


class LLMExtractor:
    PROMPTS = [
        "Extract all numbers in JSONL format with 'material', 'property', 'value', 'condition' columns.",
@@ -184,16 +192,52 @@ def _ask_llm(self, para : PaperTexts, prompt : str,
            reqinfo.status = 'failed'
            log.error("API request failed.")
        else:
            # reqinfo.status = 'done'
            # try:
            #     reqinfo.response_obj = json.loads(str(output))
            #     str_output = output["choices"][0]["message"]["content"]
            #     reqtok = output["usage"]["prompt_tokens"]
            #     resptok = output["usage"]["completion_tokens"]
            #     reqinfo.response = str_output
            #     reqinfo.request_tokens = reqtok
            #     reqinfo.response_tokens = resptok
            #     reqinfo.status = 'ok'
            reqinfo.status = 'done'
            try:
                reqinfo.response_obj = json.loads(str(output))
                str_output = output["choices"][0]["message"]["content"]
                reqtok = output["usage"]["prompt_tokens"]
                resptok = output["usage"]["completion_tokens"]

                # Assuming the output is in the expected format
                if isinstance(output, dict) and "choices" in output and output["choices"]:
                    str_output = output["choices"][0]["message"]["content"]
                    reqtok = output["usage"]["prompt_tokens"]
                    resptok = output["usage"]["completion_tokens"]
                else:
                    # If the expected structure is not present (i.e. using Exllama)
                    str_output = str(output)
                    reqtok = None
                    resptok = None

                reqinfo.response = str_output
                reqinfo.request_tokens = reqtok
                reqinfo.response_tokens = resptok
                reqinfo.status = 'ok'

            except json.JSONDecodeError:
                log.error("Failed to decode JSON from output.")
                reqinfo.status = 'json decode error'
                reqinfo.response_obj = str(output)
                reqinfo.response = str(output)
                reqinfo.request_tokens = None
                reqinfo.response_tokens = None

            except KeyError as err:
                log.error("Key error when parsing output: {}", err)
                reqinfo.status = 'key error'
                reqinfo.response_obj = str(output)
                reqinfo.response = str(output)
                reqinfo.request_tokens = None
                reqinfo.response_tokens = None

            except Exception as err:
                log.error("Failed to parse API output: {}", err)
                reqinfo.status = 'output parse error'
@@ -227,14 +271,18 @@ def _make_request(self, messages : list[dict]) -> dict:
                messages = messages
            )
        else:
            raise NotImplementedError("Unknown API", self.api)
            response = model.generate_text(messages)
            # raise NotImplementedError("Unknown API", self.api)

        return response

    def _extract_data(self, response : dict) -> list[dict]:
        """ Post process the LLM output and extract the embedded data. """
        data = []
        str_output = response["choices"][0]["message"]["content"]
        try:
            str_output = response["choices"][0]["message"]["content"]
        except:
            str_output = response  # when using exllama
        log.trace("Parsing LLM output: {}", str_output)

        try:
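The handling above now accepts two response shapes: OpenAI-style dicts carrying "choices" and "usage" keys, and the plain strings returned by the ExLlamaV2 path, for which token counts are unavailable. A condensed sketch of that normalization, assuming those two shapes; the helper name is illustrative:

def parse_response(output):
    # OpenAI-style responses carry text and token usage in a dict.
    if isinstance(output, dict) and output.get("choices"):
        text = output["choices"][0]["message"]["content"]
        req_tok = output["usage"]["prompt_tokens"]
        resp_tok = output["usage"]["completion_tokens"]
    else:
        # Local ExLlamaV2 generation returns a bare string; no usage stats.
        text, req_tok, resp_tok = str(output), None, None
    return text, req_tok, resp_tok
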
16 changes: 8 additions & 8 deletions notebooks/Abstract-Extracted_Data.ipynb
@@ -52,7 +52,7 @@
},
{
"cell_type": "code",
"execution_count": 8,
"execution_count": 5,
"metadata": {},
"outputs": [],
"source": [
@@ -62,7 +62,7 @@
},
{
"cell_type": "code",
"execution_count": 9,
"execution_count": 6,
"metadata": {},
"outputs": [
{
@@ -168,7 +168,7 @@
"[82 rows x 2 columns]"
]
},
"execution_count": 9,
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
@@ -196,17 +196,17 @@
},
{
"cell_type": "code",
"execution_count": 10,
"execution_count": 7,
"metadata": {},
"outputs": [],
"source": [
"df.save('notebooks/Abstract-Extracted_Data.csv')\n",
"propcount.save('notebooks/Abstract-Property_Count.csv')"
"df.save('data/Abstract-Extracted_Data.csv')\n",
"propcount.save('data/Abstract-Property_Count.csv')"
]
},
{
"cell_type": "code",
"execution_count": 11,
"execution_count": 8,
"metadata": {},
"outputs": [
{
@@ -217,7 +217,7 @@
"dtype: object"
]
},
"execution_count": 11,
"execution_count": 8,
"metadata": {},
"output_type": "execute_result"
}
695 changes: 657 additions & 38 deletions notebooks/Analyze-Select-1k.ipynb

Large diffs are not rendered by default.
