import yt_dlp
import asyncio
from fastapi import FastAPI, HTTPException, Query
from yt_dlp import YoutubeDL
import torch
import subprocess
import numpy as np
import os, subprocess
import whisperx
import gc
import subprocess
from groq import Groq
import uvicorn
from chat_module import chat_with_groq
from helpers import *
from dotenv import load_dotenv
import json


# Load environment variables (e.g. PORT) from a .env file, if present.
load_dotenv()

# FastAPI initialization.
# PORT falls back to 8000 (uvicorn's default) so that a missing PORT variable
# does not crash the import with int(None) -> TypeError.
PORT = int(os.getenv("PORT", "8000"))
app = FastAPI()

# Output directory for generated shorts (handed to trim_video by transcript_audio).
save_dir = "assets/shorts"


# Automatic device selection, currently disabled in favour of forcing the CPU:
# if torch.cuda.is_available():
#     device = "cuda"
#     batch_size = 8  
#     compute_type = "float16"
# else:
#     device = "cpu"
#     batch_size = 4  
#     compute_type = "int8"  

# WhisperX runtime settings used by transcript_audio.
device = "cpu"
batch_size = 16
compute_type = "float32"

# Download YouTube Video
async def download_yt_video(link: str) -> str:
    """Download a YouTube video (best video + best audio) into ``assets/videos``.

    The blocking yt-dlp download runs in the default thread-pool executor so the
    event loop is not stalled.

    Args:
        link: YouTube video URL.

    Returns:
        Path of the downloaded file. A ``.webm``/``.mkv`` extension is mapped to
        ``.mp4`` because merged downloads are requested as mp4.

    Raises:
        HTTPException: 500 with a "Download failed: ..." detail on any error.
    """
    output_path = "assets/videos"
    ydl_opts = {
        'format': 'bestvideo+bestaudio/best',
        'outtmpl': f"{output_path}/%(title)s.%(ext)s",
        'merge_output_format': 'mp4',
    }

    try:
        # get_running_loop() is the supported way to obtain the loop from inside a coroutine.
        loop = asyncio.get_running_loop()
        with yt_dlp.YoutubeDL(ydl_opts) as ydl:
            info_dict = await loop.run_in_executor(None, lambda: ydl.extract_info(link, download=True))
            video_file = ydl.prepare_filename(info_dict)
    except Exception as e:
        raise HTTPException(status_code=500, detail=f"Download failed: {str(e)}") from e

    # Only touch the real extension. str.replace would also rewrite a ".webm"/".mkv"
    # that happens to appear inside the video title.
    base, ext = os.path.splitext(video_file)
    if ext in ('.webm', '.mkv'):
        video_file = base + '.mp4'

    print(f"Downloaded video file: {video_file}")
    return video_file


# Download YouTube Audio
def download_youtube_audio(youtube_url: str) -> str:
    """Download the audio track of a YouTube video as a 192 kbps MP3 into ``assets/audios``.

    Args:
        youtube_url: YouTube video URL.

    Returns:
        Path of the resulting ``.mp3`` file.

    Raises:
        Exception: "Error downloading audio: ..." wrapping the underlying error.
    """
    output_path = "assets/audios"
    ydl_opts = {
        'format': 'bestaudio/best',
        'postprocessors': [{
            'key': 'FFmpegExtractAudio',
            'preferredcodec': 'mp3',
            'preferredquality': '192',
        }],
        'outtmpl': f'{output_path}/%(title)s.%(ext)s',
        'quiet': True
    }

    try:
        with YoutubeDL(ydl_opts) as ydl:
            info_dict = ydl.extract_info(youtube_url, download=True)
            downloaded_path = ydl.prepare_filename(info_dict)
    except Exception as e:
        raise Exception(f"Error downloading audio: {str(e)}") from e

    # prepare_filename reports the pre-postprocessing extension (.webm, .m4a, .opus, ...).
    # FFmpegExtractAudio always writes <stem>.mp3, so swap only the real extension.
    return os.path.splitext(downloaded_path)[0] + '.mp3'

def classify_text(text):
    """Classify a piece of text as a question or an answer using a Groq-hosted LLM.

    The API key is read from the ``GROQ_API_KEY`` environment variable. The key
    previously hard-coded here was committed to source and should be revoked.

    Args:
        text: The utterance to classify.

    Returns:
        "question" if the model's reply contains "question", otherwise "answer".
    """
    prompt = f"""
    Given the following text, classify it as a "question" or "answer":
    Text: "{text}"
    Respond with only "question" or "answer".
    """
    client = Groq(api_key=os.getenv("GROQ_API_KEY"))
    completion = client.chat.completions.create(
        model="llama3-8b-8192",
        messages=[{"role": "user", "content": prompt}],
        # Only a one-word label is needed: keep it deterministic and short.
        temperature=0,
        max_completion_tokens=10,
        top_p=1,
        stream=False,
        stop=None,
    )
    classification = (completion.choices[0].message.content or "").strip().lower()
    return "question" if "question" in classification else "answer"
    



# def trim_video(input_video, output_video, start_time, end_time):
#     """ Uses FFmpeg to trim a video from start_time to end_time """
#     command = [
#         "ffmpeg", "-i", input_video, "-ss", str(start_time), "-to", str(end_time),
#         "-c:v", "libx264", "-c:a", "aac", "-strict", "experimental", "-y", output_video
#     ]
#     subprocess.run(command, stdout=subprocess.DEVNULL, stderr=subprocess.DEVNULL)

# Speaker Diarization & Transcription
def transcript_audio(audio_file_path,video_file_path):
    """Transcribe, diarize and cut a Q&A recording into short clips.

    Pipeline:
        1. WhisperX transcription, word alignment and speaker diarization.
        2. Merge consecutive segments of the same speaker.
        3. Ask the LLM (``chat_with_groq``) to keep only relevant Q/A pairs, retrying
           up to 5 times until ``extract_timeframes`` yields any pairs.
        4. Shorten pairs longer than 45 timestamp units (seconds in WhisperX output)
           and drop those that remain too long.
        5. Cut the remaining pairs out of the video with ``trim_video`` into ``save_dir``.

    The Hugging Face token for diarization is read from the ``HF_TOKEN`` environment
    variable. The token previously hard-coded here should be revoked.

    Args:
        audio_file_path: Path of the audio file to transcribe.
        video_file_path: Path of the video the clips are cut from.

    Returns:
        ``(timestamp_pairs, short_urls)``. ``timestamp_pairs`` is a list of dicts with
        keys idx_no, start_time, end_time, question, answer. ``short_urls`` is whatever
        ``trim_video`` returns.
    """
    # Load WhisperX model
    model = whisperx.load_model("large-v2", device, compute_type=compute_type)
    audio = whisperx.load_audio(audio_file_path)

    # Transcription
    result = model.transcribe(audio, batch_size=batch_size)
    # Drop the reference BEFORE collecting/emptying caches, otherwise nothing is freed.
    del model
    _free_memory()

    # Alignment
    model_a, metadata = whisperx.load_align_model(language_code=result["language"], device=device)
    result = whisperx.align(result["segments"], model_a, metadata, audio, device, return_char_alignments=False)
    del model_a, metadata
    _free_memory()

    # Speaker Diarization
    diarize_model = whisperx.DiarizationPipeline(use_auth_token=os.getenv("HF_TOKEN"), device=device)
    diarize_segments = diarize_model(audio)
    result = whisperx.assign_word_speakers(diarize_segments, result)
    del diarize_model
    _free_memory()

    merged_segments = _merge_speaker_segments(result["segments"])
    merged_segments_text = "".join(
        f"[{seg['start']:.2f} - {seg['end']:.2f}] {seg['speaker']}: {seg['text']}\n"
        for seg in merged_segments
    )

    prompt = f'''
    {merged_segments_text}
    identify the questions and answers in the text above and replace SPEAKER_01 and SPEAKER_00 with question or answer and keep only relavent question and answer pairs.and remove unnecessary text. because i want to reduce duration.
    Keep in mind the output should be the same as the input just replace SPEAKER_01 and SPEAKER_00 with question or answer and keep only relavent question and answer pairs.and remove unnecessary text. also return the timestamp mentioned in the input text.
    remove the greeting questions and answers.
    [720.17 - 722.31] Question: That's the end of the speaking test. 
    [722.91 - 723.57] Answer: Thank you so much.
    remove such type of pairs from the text. keep only relavent question and answer pairs.
    '''

    # The LLM output is not always parseable; retry until some timeframes come back.
    max_attempts = 5
    qa_timeframes = []
    for _ in range(max_attempts):
        qa_struct = chat_with_groq(prompt)
        print("QA Structured")
        qa_timeframes = extract_timeframes(qa_struct)
        if len(qa_timeframes) > 0:
            break
    else:
        print("Failed to get valid QA timeframes after 5 attempts.")

    print("Timestamp Pairs")
    timestamp_pairs = [
        {
            'idx_no': idx,
            'start_time': item['question']['start_time'],
            'end_time': item['answer']['end_time'],
            'question': item['question']['text'],
            'answer': item['answer']['text'],
        }
        for idx, item in enumerate(qa_timeframes, start=1)
    ]

    timestamp_pairs = _shorten_long_pairs(timestamp_pairs, result, duration_threshold=45, max_attempts=3)

    short_urls = trim_video(video_file_path, save_dir, timestamp_pairs)
    return timestamp_pairs, short_urls


def _free_memory():
    """Run the garbage collector and, if CUDA is available, release cached GPU memory."""
    gc.collect()
    if torch.cuda.is_available():
        torch.cuda.empty_cache()


def _merge_speaker_segments(segments):
    """Collapse consecutive segments from the same speaker into one speaker turn.

    Args:
        segments: Segment dicts with "start", "end", "text" and optionally "speaker".
            A missing speaker is treated as "Unknown".

    Returns:
        A list of {"speaker", "start", "end", "text"} dicts. Each merged turn spans
        from its first segment's start to its last segment's end, with the texts
        joined by a single space.
    """
    merged = []
    for segment in segments:
        speaker = segment.get("speaker", "Unknown")
        if merged and merged[-1]["speaker"] == speaker:
            merged[-1]["text"] += " " + segment["text"]
            merged[-1]["end"] = segment["end"]
        else:
            merged.append({"speaker": speaker, "start": segment["start"], "end": segment["end"], "text": segment["text"]})
    return merged


def _shorten_long_pairs(timestamp_pairs, result, duration_threshold, max_attempts):
    """Trim over-long Q/A pairs sentence by sentence, then drop those still too long.

    Each attempt removes the trailing sentence from every pair whose duration exceeds
    ``duration_threshold``. The pair's end_time moves to wherever ``check_line``
    locates the new last sentence in ``result``. If ``check_line`` returns a falsy
    value, the pair is left unchanged.

    Args:
        timestamp_pairs: Dicts with idx_no, start_time, end_time, question, answer.
            They are updated in place.
        result: Diarized WhisperX result passed through to ``check_line``.
        duration_threshold: Maximum allowed ``end_time - start_time``.
        max_attempts: Number of trimming rounds.

    Returns:
        The pairs whose duration is within the threshold.
    """
    for attempt in range(1, max_attempts + 1):
        long_pairs = [p for p in timestamp_pairs if (p["end_time"] - p["start_time"]) > duration_threshold]
        print("Long Shot Indexes:", [p["idx_no"] for p in long_pairs])

        if not long_pairs:
            break

        for pair in long_pairs:
            answer = pair["answer"]
            if not answer:
                continue
            # Drop the last two '.'-separated pieces. For an answer ending in '.',
            # that is the empty tail plus the final sentence.
            kept_sentences = answer.split(".")[:-2]
            if not kept_sentences:
                # Too short to trim (previously an IndexError); dropped below if still too long.
                continue
            item_answer = ".".join(kept_sentences) + "."
            print("Item Answer:", item_answer)
            last_line = kept_sentences[-1].strip() + "."
            end_time = check_line(last_line, result)
            if end_time:
                pair["end_time"] = end_time
                pair["answer"] = item_answer

        print(f"Attempt {attempt} completed.")

    return [p for p in timestamp_pairs if (p["end_time"] - p["start_time"]) <= duration_threshold]


    

    
    


@app.get("/")
async def root():
    """Landing endpoint: return a static description of the service."""
    payload = {"message": "Platform to get YouTube Video Shorts"}
    return payload

@app.post("/process-video/")
async def video_process(link: str = Query(..., description="YouTube video URL")):
    """Download a YouTube video, cut it into Q&A shorts and report the results.

    Args:
        link: YouTube video URL (required query parameter).

    Returns:
        JSON object with the video path ("message1"), the audio path ("message2"),
        the selected Q/A pairs ("Q-A Segments") and whatever ``trim_video`` returned
        ("Shorts URL").

    Raises:
        HTTPException: 500 if the video or audio download fails.
    """
    video_path = await download_yt_video(link)

    # download_youtube_audio and transcript_audio are blocking (yt-dlp, WhisperX, LLM
    # calls). Run them in the default thread pool so the event loop keeps serving
    # other requests. Concurrent requests may now run these steps in parallel.
    loop = asyncio.get_running_loop()
    try:
        audio_path = await loop.run_in_executor(None, download_youtube_audio, link)
    except Exception as e:
        # Report as JSON with a detail message, like download_yt_video does.
        raise HTTPException(status_code=500, detail=str(e)) from e

    print(video_path, audio_path)
    shorts_pair, shorts_url = await loop.run_in_executor(None, transcript_audio, audio_path, video_path)
    return {"status": "success", "message1": video_path, "message2": audio_path, "Q-A Segments": shorts_pair, "Shorts URL": shorts_url}

if __name__ == "__main__":
    # Serve the FastAPI app on all interfaces with auto-reload enabled.
    # "app:app" assumes this module is importable as ``app``.
    app_import_string = "app:app"
    uvicorn.run(app_import_string, host="0.0.0.0", port=PORT, reload=True)

