import numpy as np
import cv2
import json
import sys
import mediapipe as mp
from os.path import join
from analysis_pipeline import extract_joint_frames, PoseNotFoundError, analyse_shots, pose2bvh

# Module-level shortcuts to MediaPipe's pose solution and its drawing utilities.
# NOTE(review): neither alias is referenced in the visible part of this file —
# confirm nothing else relies on them before removing.
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils

class AnalysisFailedError(Exception):
    """Raised by analyse_video when the analysis cannot be completed.

    Examples in this file: no frames could be read from the video, or no
    pose could be detected in it.
    """

def classify_spin_rate(spin_rate):
    """Label a spin-rate measurement as "flat", "topspin" or "slice".

    The value(s) in ``spin_rate`` are averaged with ``np.mean`` and the
    average is bucketed: strictly below 0.1 is "flat", strictly above 0.3 is
    "topspin", and anything else (both bounds included) is "slice".
    """
    average = np.mean(spin_rate)
    # The two thresholds cannot both match, so the order of the checks is free.
    if average > 0.3:
        return "topspin"
    return "flat" if average < 0.1 else "slice"

def generate_feedback(shot_analysis):
    """Generate actionable feedback based on shot analysis data

    Parameters:
        shot_analysis -- mapping holding parallel per-shot sequences under the keys
                         "classifications", "shot_consistency", "speeds",
                         "spin_type_left", "spin_type_right", "spin_rate_left"
                         and "spin_rate_right"; the number of shots is taken from
                         the length of "classifications"

    Returns:
        A flat list of feedback strings. Each shot contributes, in order, one
        consistency message, one speed message, at most one spin-type message
        (none when the spin types match none of the handled combinations) and
        one spin-rate message.
    """
    feedback = []
    shot_count = len(shot_analysis["classifications"])
    # Index-based access keeps the per-key lookups lazy: with zero shots no
    # other key of shot_analysis is touched.
    for i in range(shot_count):
        # NOTE(review): consistency and speed are compared directly, so they are
        # assumed to be scalar per-shot values — confirm against analyse_shots.
        consistency = shot_analysis["shot_consistency"][i]
        speed = shot_analysis["speeds"][i]
        spin_type_left = shot_analysis["spin_type_left"][i]
        spin_type_right = shot_analysis["spin_type_right"][i]
        # A per-shot spin rate may be an array of samples; reduce it with the
        # mean (as classify_spin_rate does) so the threshold test below never
        # hits NumPy's "truth value of an array is ambiguous" error. For scalar
        # inputs the mean is the value itself.
        spin_rate_left = np.mean(shot_analysis["spin_rate_left"][i])
        spin_rate_right = np.mean(shot_analysis["spin_rate_right"][i])

        # Feedback based on consistency
        if consistency > 0.2:
            feedback.append("Try to make your shot more consistent. A more stable wrist speed will help.")
        else:
            feedback.append("Good job maintaining consistency in your shot!")

        # Feedback based on speed
        if speed < 2:
            feedback.append("Consider increasing your swing speed for more power.")
        else:
            feedback.append("Great power! Keep up the good swing speed.")

        # Feedback based on spin type (topspin on either hand wins over slice)
        if spin_type_left == 'topspin' or spin_type_right == 'topspin':
            feedback.append("Excellent job on generating topspin! Keep the racket face slightly closed for more.")
        elif spin_type_left == 'slice' or spin_type_right == 'slice':
            feedback.append("Focus on making a cleaner slice shot by angling your racket more.")
        elif spin_type_left == 'flat' and spin_type_right == 'flat':
            feedback.append("Try adding more spin by increasing wrist angle during the shot.")

        # Feedback on spin rate
        if spin_rate_left < 0.5 and spin_rate_right < 0.5:
            feedback.append("Try generating more spin by adjusting the angle of your racket at impact.")
        else:
            feedback.append("Good spin rate! Keep focusing on maintaining racket control.")

    return feedback

def _read_all_frames(video_path):
    """Decode every frame of the video; return (fps, reported frame count, frames)."""
    cap = cv2.VideoCapture(video_path)
    try:
        fps = cap.get(cv2.CAP_PROP_FPS)
        total_frames = cap.get(cv2.CAP_PROP_FRAME_COUNT)
        frames = []
        while cap.isOpened():
            ret, frame = cap.read()
            if not ret:
                break
            frames.append(frame)
    finally:
        # Release the capture even if decoding raises part-way through.
        cap.release()
    return fps, total_frames, frames


def _to_jsonable(value):
    """Turn NumPy arrays/scalars into plain Python values that json.dump accepts."""
    if isinstance(value, np.ndarray):
        return value.tolist()
    if isinstance(value, np.generic):
        return value.item()
    return value


def analyse_video(video_path, out_dir):
    """Perform 3D shot analysis on video feed

    Reads all frames of the video, extracts joint frames, runs the shot
    analysis, drops detections shorter than 0.6 s, classifies the spin of each
    remaining shot and attaches feedback to it. The result is written to
    "shot_analysis.json" inside out_dir (the only file this function writes).

    Parameters:
        video_path -- path to video file to be analysed
        out_dir -- path to an existing directory to write shot_analysis.json to

    Returns:
        The analysis as a dict with the keys "fps", "total_duration" (seconds,
        derived from the frame count reported by the video) and "shots" (one
        dict per retained shot). Each shot's "feedback" is a single string made
        of that shot's feedback messages joined by spaces.

    Raises:
        AnalysisFailedError -- if no frames could be read, the frame rate is
            not a positive number, no pose could be detected, or the shot
            analysis lacks the spin-rate / consistency data
    """
    fps, total_frames, frames = _read_all_frames(video_path)

    # Check for frames first: an unreadable video reports fps == 0, and dividing
    # by it would mask the real problem behind a ZeroDivisionError.
    if not frames:
        raise AnalysisFailedError("Video analysis failed because no frames could be read from the video.")
    if not fps > 0:
        raise AnalysisFailedError("Video analysis failed because the frame rate of the video could not be determined.")

    total_duration = total_frames / fps  # total duration in seconds

    try:
        joint_frames, _mp_landmarks = extract_joint_frames(frames)
    except PoseNotFoundError as err:
        raise AnalysisFailedError("Video analysis failed because no pose could be detected.") from err

    shot_analysis = analyse_shots(joint_frames)

    # Ensure that the per-shot data used below is present in the shot_analysis dictionary
    if 'spin_rate_left' not in shot_analysis or 'spin_rate_right' not in shot_analysis:
        raise AnalysisFailedError("'spin_rate_left' or 'spin_rate_right' missing from shot analysis data.")
    if 'shot_consistency' not in shot_analysis:
        raise AnalysisFailedError("'shot_consistency' missing from shot analysis data.")

    detections = zip(
        shot_analysis["intervals"],
        shot_analysis["classifications"],
        shot_analysis["speeds"],
        shot_analysis["hands"],
        shot_analysis["spin_rate_left"],
        shot_analysis["spin_rate_right"],
        shot_analysis["shot_consistency"],
    )

    # assume shot intervals less than 0.6s long are invalid detections, so ignore these
    min_shot_seconds = 0.6

    shot_entries = []
    for (start_t, end_t), (classification, confidence), speed, hand, spin_rate_left, spin_rate_right, shot_consistency in detections:
        # Filtering while building a new list (instead of removing from the list
        # being iterated) guarantees that no detection is skipped.
        if (end_t - start_t) / fps < min_shot_seconds:
            continue

        # Classify the spin rate of each hand as flat, topspin, or slice
        shot_type_left = classify_spin_rate(spin_rate_left)
        shot_type_right = classify_spin_rate(spin_rate_right)

        # Feedback is generated from a one-shot view of the data so that it
        # always belongs to this shot, even after short detections were dropped,
        # and so that it sees the spin types computed just above.
        shot_feedback = generate_feedback({
            "classifications": [(classification, confidence)],
            "shot_consistency": [shot_consistency],
            "speeds": [speed],
            "spin_type_left": [shot_type_left],
            "spin_type_right": [shot_type_right],
            "spin_rate_left": [spin_rate_left],
            "spin_rate_right": [spin_rate_right],
        })

        shot_entries.append({
            "start_frame_idx": int(start_t),
            "end_frame_idx": int(end_t),
            "classification": classification,
            "accuracy": float(confidence),
            "speed": _to_jsonable(speed),
            "hand": hand,
            "spin_rate_left": _to_jsonable(spin_rate_left),
            "spin_rate_right": _to_jsonable(spin_rate_right),
            "shot_consistency": _to_jsonable(shot_consistency),
            "spin_type_left": shot_type_left,
            "spin_type_right": shot_type_right,
            # Kept as one string so the field's type stays what it was.
            "feedback": " ".join(shot_feedback),
        })

    analysis_json = {
        "fps": fps,
        "total_duration": total_duration,
        "shots": shot_entries,
    }

    # Save the analysis result as a JSON file
    with open(join(out_dir, "shot_analysis.json"), "w") as json_file:
        json.dump(analysis_json, json_file, indent=2)

    # Return the analysis JSON for further processing or debugging
    return analysis_json

if __name__ == "__main__":
    # Command-line entry point: <path to video> <path to output directory>.
    args = sys.argv[1:]
    if len(args) != 2:
        # Passing a string to sys.exit writes it to stderr and exits with
        # status 1, so calling scripts can detect the usage error.
        sys.exit("Incorrect number of arguments. First argument should be a path to a video. Second argument should be a path to an output directory.")
    analyse_video(args[0], args[1])
