import streamlit as st
from langchain_community.vectorstores import FAISS
from sentence_transformers import SentenceTransformer
from langchain.embeddings.base import Embeddings
from langchain.prompts import PromptTemplate
from langchain.chains import RetrievalQA
from dotenv import load_dotenv
import os
from groq import Groq

# Populate os.environ from a local .env file (if one exists) before any
# secrets are read below.
load_dotenv()

# Module-wide Groq client used by chat_with_groq; the API key is taken from
# the GROQ_API_KEY environment variable.
client = Groq(api_key=os.environ.get("GROQ_API_KEY"))

def chat_with_groq(prompt: str, model: str = "llama-3.3-70b-versatile") -> str:
    """Send *prompt* to the Groq chat-completions API and return the reply text.

    The request carries a fixed system message framing the model as an
    Ayurvedic expert, followed by *prompt* as the user message.

    Args:
        prompt: The user message (in this app, the fully formatted RAG prompt).
        model: Groq model identifier to query.

    Returns:
        The assistant's reply text, or ``""`` if the API returned no text
        content. Exceptions are not propagated: on any failure a string of the
        form ``"An error occurred: <exception>"`` is returned instead, so the
        caller can display it directly.
    """
    try:
        chat_completion = client.chat.completions.create(
            messages=[
                {"role": "system", "content": "You are an Ayurvedic expert with deep knowledge of Ayurvedic practices, remedies, and diagnostics."},
                {"role": "user", "content": prompt},
            ],
            model=model,
        )
        # message.content is Optional in the response schema; normalise None to
        # "" so callers always receive the str this function promises.
        return chat_completion.choices[0].message.content or ""
    except Exception as e:  # deliberate UI boundary: surface the error as text
        return f"An error occurred: {e}"

# Define a custom embedding wrapper for LangChain
class SentenceTransformerEmbeddings(Embeddings):
    """LangChain ``Embeddings`` adapter around a sentence-transformers model.

    LangChain's ``Embeddings`` interface specifies plain Python lists (a list of
    float vectors for documents, one float vector for a query), while
    ``SentenceTransformer.encode`` returns NumPy arrays. Results are converted
    with ``.tolist()`` so every LangChain consumer receives the documented
    types, not only FAISS (which happens to tolerate arrays).
    """

    def __init__(self, model_name="all-MiniLM-L6-v2"):
        """Load the sentence-transformers model named *model_name*."""
        self.model = SentenceTransformer(model_name)

    def embed_documents(self, texts):
        """Embed a batch of texts; returns one list-of-floats vector per text."""
        return self.model.encode(texts, show_progress_bar=True).tolist()

    def embed_query(self, text):
        """Embed a single query string; returns its vector as a list of floats."""
        return self.model.encode([text], show_progress_bar=False)[0].tolist()

# Path to FAISS index
faiss_index_path = "faiss_index"  # Path where the FAISS index was saved in backend


# Streamlit re-executes this whole script on every widget interaction, so without
# caching the embedding model and FAISS index would be reloaded for every query.
# st.cache_resource keeps one shared instance per index path. show_spinner=False
# stops the loader from emitting a UI element ahead of st.set_page_config.
@st.cache_resource(show_spinner=False)
def _load_faiss_index(index_path: str):
    """Load the embedding model and the FAISS index stored at *index_path*.

    Returns an ``(embedding_model, db)`` tuple.
    """
    model = SentenceTransformerEmbeddings()  # Use the same embedding model as in backend
    # SECURITY: allow_dangerous_deserialization=True lets FAISS.load_local unpickle
    # the saved docstore. Only load index files you produced yourself; never point
    # this at untrusted or user-supplied files.
    store = FAISS.load_local(index_path, model, allow_dangerous_deserialization=True)
    return model, store


embedding_model, db = _load_faiss_index(faiss_index_path)

# RAG prompt: {context} receives the retrieved passages and {question} the
# user's query; the formatted result is sent to the LLM as the user message.
prompt_template = PromptTemplate(
    template=(
        "You are an Ayurvedic expert with deep knowledge of Ayurvedic practices, remedies, and diagnostics. "
        "Use the provided Ayurvedic context to answer the question thoughtfully and accurately.\n\n"
        "Context:\n{context}\n\n"
        "Question:\n{question}\n\n"
        "Answer as an Ayurvedic expert:"
    ),
    input_variables=["context", "question"],
)

# Define the RetrievalQA Chain
class GroqRetrievalQA:
    """Minimal retrieval-augmented QA: retrieve documents, fill the prompt, ask Groq.

    Args:
        retriever: A LangChain-style retriever exposing ``invoke(query)`` (or the
            legacy ``get_relevant_documents(query)``) that returns documents
            with a ``page_content`` attribute.
        prompt_template: Object whose ``format`` accepts ``context`` and
            ``question`` keyword arguments.
    """

    def __init__(self, retriever, prompt_template):
        self.retriever = retriever
        self.prompt_template = prompt_template

    def run(self, query):
        """Answer *query* from retrieved context; returns chat_with_groq's reply."""
        # get_relevant_documents is deprecated in langchain-core in favour of the
        # Runnable ``invoke`` API; prefer invoke, but keep working with retrievers
        # that only provide the legacy method.
        invoke = getattr(self.retriever, "invoke", None)
        if invoke is not None:
            docs = invoke(query)
        else:
            docs = self.retriever.get_relevant_documents(query)
        context = "\n".join(doc.page_content for doc in docs)
        prompt = self.prompt_template.format(context=context, question=query)
        return chat_with_groq(prompt)
# Single QA pipeline instance: FAISS-backed retriever + the RAG prompt above.
qa_chain = GroqRetrievalQA(retriever=db.as_retriever(), prompt_template=prompt_template)

# Streamlit UI
st.set_page_config(page_title="Ayurveda Chatbot", layout="wide")
st.title("Ayurveda Chatbot")

st.subheader("Ask your Ayurvedic Question")
# Strip surrounding whitespace so a blank or whitespace-only entry does not
# trigger a retrieval and an LLM call.
query = st.text_input("Enter your query:").strip()
if query:
    # The spinner covers only the slow retrieval + generation step.
    with st.spinner("Retrieving answer..."):
        response = qa_chain.run(query)
    st.markdown(f"### Answer:\n{response}")
