import json
import re
import nltk
import string
from collections import Counter
import torch
from elemeta.nlp.extractors.high_level.regex_match_count import RegexMatchCount
from elemeta.nlp.extractors.high_level.word_regex_matches_count import (
WordRegexMatchesCount,
)
from elemeta.nlp.extractors.low_level.abstract_metafeature_extractor import (
AbstractMetafeatureExtractor,
)
from elemeta.nlp.metafeature_extractors_runner import MetafeatureExtractorsRunner
from huggingface_hub import hf_hub_download
from llama_cpp import Llama
from langchain_community.chat_models import ChatLlamaCpp
from langchain_openai import ChatOpenAI
nltk.download("punkt_tab")
NO_ANS = "[CLS]"
INVALID_QUESTION = -1
# pylint:disable=too-many-boolean-expressions
def load_json(path):
"""
Load a JSON file from a given path.
Args:
path (str) : path to the JSON file
Returns:
dict: parsed content of the JSON file
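Example:
    # Illustrative sketch; "qa_dataset.json" is a hypothetical file path.
    data = load_json("qa_dataset.json")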
"""
with open(path) as json_file:
o_file = json_file.read()
return json.loads(o_file)
def filter_questions(exp_ans, pred_ans):
"""
Check whether the expected answer and the predicted answer match once normalised with clean_text.
Args:
exp_ans (str) : expected answer
pred_ans (str) : predicted answer
Returns:
str: "VALID" if the answers are the same, "NO MATCH" otherwise
"""
if pred_ans == NO_ANS:
return "NO MATCH"
if clean_text(exp_ans) != clean_text(pred_ans):
return "NO MATCH"
return "VALID"
def clean_text(text):
"""
Lowercase a text, strip punctuation, drop a few common English stopwords, and collapse repeated whitespace.
Args:
text (str) : text to clean
Returns:
str: cleaned text
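Example:
    # Illustrative sketch of the normalisation applied before answers are compared.
    clean_text("The answer, in our docs!")  # -> "answer docs"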
"""
# TODO: improve
# TODO: add support for French
text = text.lower()
text = text.translate(str.maketrans("", "", string.punctuation))
text = re.sub(r"\b(a|an|the|in|our)\b", " ", text)
return re.sub(" +", " ", text).strip()
def raw_f1_score(a_gold, a_pred):
"""
Compute the token-level F1 score between a gold answer and a predicted answer.
Args:
a_gold (str) : expected answer
a_pred (str) : predicted answer
Returns:
float : F1 score
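Example:
    # Illustrative sketch; both answers are normalised with clean_text first.
    round(raw_f1_score("the cat sat", "cat sat down"), 2)  # -> 0.8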
"""
if a_pred == "":
return 0.0
gold_toks = clean_text(a_gold).split()
pred_toks = clean_text(a_pred).split()
common = Counter(gold_toks) & Counter(pred_toks)
num_same = sum(common.values())
if num_same == 0:
return 0.0
precision = 1.0 * num_same / len(pred_toks)
recall = 1.0 * num_same / len(gold_toks)
f1_score = (2 * precision * recall) / (precision + recall)
return f1_score
def non_personal(question, nlp, lan="en"):
"""
Check whether a question is impersonal, i.e. free of first- and second-person subject and possessive pronouns.
Args:
question (str) : question to check
nlp (spacy.lang) : spacy language model
lan (str) : language of the question. Defaults to "en".
Returns:
bool: True if the question does not contain personal pronouns, False otherwise
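Example:
    # Illustrative sketch; assumes the spaCy model "en_core_web_sm" is installed.
    # Results depend on the dependency parse produced by the model.
    import spacy
    nlp = spacy.load("en_core_web_sm")
    non_personal("What is the capital of France?", nlp)  # typically True
    non_personal("What did I write in my report?", nlp)  # typically False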
"""
question_tok = nlp(question)
for tok in question_tok:
if tok.dep_ == "nsubj" and lan == "en":
if (
tok.text.lower() == "i" or tok.text.lower() == "you"
):
return False
elif tok.dep_ == "poss" and lan == "en":
if (
tok.text.lower() == "my" or tok.text.lower() == "your"
):
return False
# french
elif tok.dep_ == "nsubj" and lan == "fr":
if (
tok.text.lower() == "je"
or tok.text.lower() == "tu"
or tok.text.lower() == "vous"
):
return False
elif tok.dep_ == "poss" and lan == "fr":
if tok.text.lower() in [
"mon",
"ton",
"votre",
"ma",
"ta",
"vos",
"mes",
"tes",
]:
return False
return True
def get_llama_model(
repo_id: str = "TheBloke/Llama-2-7b-Chat-GGUF",
filename: str = "llama-2-7b-chat.Q2_K.gguf",
model_path=False,
):
"""
Download and return a Llama model from HuggingFace Hub.
Args:
repo_id (str) : HuggingFace Hub repo id of the model. Defaults to "TheBloke/Llama-2-7b-Chat-GGUF".
filename (str) : model filename to download. Defaults to "llama-2-7b-chat.Q2_K.gguf".
model_path (str, optional) : local path to an already-downloaded model file; if provided, the download is skipped. Defaults to False.
Returns:
llama_cpp.Llama: Llama model
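Example:
    # Illustrative sketch; downloads the GGUF weights from the Hub on first use
    # unless model_path points to a local file.
    llm = get_llama_model()
    completion = llm("Q: What is the capital of France? A:", max_tokens=16)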
"""
if not model_path:
model_path = hf_hub_download(repo_id, filename)
if torch.cuda.is_available():
lcpp_llm = Llama(
model_path=model_path,
main_gpu=0,
n_gpu_layers=40,  # number of layers to offload to the GPU; tune for your hardware
n_batch=1024,
logits_all=True,
n_ctx=1024,
device="cuda",
)
else:
lcpp_llm = Llama(
model_path=model_path,
logits_all=True,
n_ctx=1024,
)
return lcpp_llm
def get_langchain_llama_model(
repo_id: str = "TheBloke/Llama-2-7b-Chat-GGUF",
filename: str = "llama-2-7b-chat.Q2_K.gguf",
model_path=False,
):
"""
Download and return a LlamaCPP model from LangChain, loaded from the HuggingFace Hub.
Args:
repo_id (str) : HuggingFace Hub repo id of the model. Defaults to "TheBloke/Llama-2-7b-Chat-GGUF".
filename (str) : model filename to download. Defaults to "llama-2-7b-chat.Q2_K.gguf".
model_path (str, optional) : local path to an already-downloaded model file; if provided, the download is skipped. Defaults to False.
Returns:
langchain_community.chat_models.ChatLlamaCpp: LlamaCPP model from LangChain.
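Example:
    # Illustrative sketch; wraps the same GGUF weights in a LangChain chat model.
    chat = get_langchain_llama_model()
    reply = chat.invoke("What is the capital of France?")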
"""
if not model_path:
model_path = hf_hub_download(repo_id, filename)
if torch.cuda.is_available():
lcpp_llm = ChatLlamaCpp(
model_path=model_path,
n_gpu_layers=40,  # number of layers to offload to the GPU; tune for your hardware
n_batch=1024,
logits_all=True,
logprobs=1,
n_ctx=1024,
device="cuda",
)
else:
lcpp_llm = ChatLlamaCpp(
model_path=model_path,
logits_all=True,
logprobs=1,
n_ctx=1024,
)
return lcpp_llm
def get_langchain_gpt_model(version="gpt-3.5-turbo-0125"):
"""
Return a GPT chat model from LangChain OpenAI.
Args:
version (str) : model version. Defaults to "gpt-3.5-turbo-0125".
Returns:
langchain_openai.ChatOpenAI: GPT model from LangChain OpenAI.
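Example:
    # Illustrative sketch; assumes OPENAI_API_KEY is set in the environment.
    chat = get_langchain_gpt_model()
    reply = chat.invoke("Give a one-sentence definition of the F1 score.")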
"""
return ChatOpenAI(model=version)
def check_list_type(array: list, list_type: type):
"""
Check whether an object is a list whose items are all of a given type.
Args:
array (list): object to check
list_type (type): expected type of every item
Returns:
bool: True if array is a list and every item is an instance of list_type, False otherwise
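Example:
    # Illustrative sketch.
    check_list_type(["who", "what"], str)  # -> True
    check_list_type(["who", 42], str)  # -> False
    check_list_type("not a list", str)  # -> False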
"""
if not isinstance(array, list):
return False
return all(isinstance(item, list_type) for item in array)
# pylint:disable=invalid-name