
import mediapipe as mp
import numpy as np
from scipy import ndimage
from analysis_pipeline import classify_shot
import copy

# MediaPipe pose-estimation solution and drawing-utility handles.
# NOTE(review): neither name is referenced by the functions visible here;
# presumably they are used elsewhere in the file or imported by callers —
# confirm before removing.
mp_pose = mp.solutions.pose
mp_drawing = mp.solutions.drawing_utils

def detect_shot_times(lw_speed, rw_speed, *, downsample_factor=4, sigma=2, threshold_std=0.5):
    """Detect frame intervals in which either wrist is swinging fast.

    The per-frame swing speed is the larger of the two wrist speeds. It is
    downsampled by ``downsample_factor`` with ``scipy.ndimage.zoom`` and then
    smoothed with a Gaussian of width ``sigma``, measured in downsampled
    samples. Samples whose smoothed speed is at or above
    ``median + threshold_std * std`` belong to a shot.

    Args:
        lw_speed: Per-frame left-wrist speed, the same length as ``rw_speed``.
        rw_speed: Per-frame right-wrist speed.
        downsample_factor: Integer decimation factor applied before smoothing.
        sigma: Gaussian smoothing width, in downsampled samples.
        threshold_std: How many standard deviations above the median the
            smoothed speed must reach to count as a shot.

    Returns:
        List of ``(start_frame, end_frame)`` tuples in original-frame indices.
        The list is empty when the input is too short to downsample.
    """
    swing_speed = np.amax(np.stack([lw_speed, rw_speed]), axis=0)
    n_frames = len(swing_speed)
    # ndimage.zoom's output length is round(n * zoom). Bail out before it would
    # yield an empty signal, which has no median and no [0] element.
    if int(round(n_frames * (1 / downsample_factor))) == 0:
        return []
    swing_speed = ndimage.zoom(swing_speed, 1 / downsample_factor)
    swing_speed = ndimage.gaussian_filter(swing_speed, sigma=sigma)
    shot_threshold = np.median(swing_speed) + threshold_std * np.std(swing_speed)

    # Use a boolean mask rather than np.sign. A sample exactly equal to the
    # threshold has sign 0, so np.diff(np.sign(...)) would report two
    # crossings for one transition and desynchronise the pairing below.
    above = swing_speed >= shot_threshold
    crossings = np.flatnonzero(np.diff(above.astype(np.int8))) + 1
    shot_times = (crossings * downsample_factor).tolist()
    if above[0]:
        # The recording starts mid-shot, so open an interval at frame 0.
        shot_times.insert(0, 0)
    if above[-1]:
        # The recording ends mid-shot, so close the interval at the last frame.
        shot_times.append(len(lw_speed) - 1)

    # Transitions alternate rise/fall, so consecutive pairs are (start, end).
    return [
        (shot_times[i], shot_times[i + 1])
        for i in range(0, len(shot_times) - 1, 2)
    ]

def calc_joint_speed(frames):
    """Return the per-frame speed magnitude of a single joint trajectory.

    Args:
        frames: Per-frame positions. The first three components of each entry
            (``[..., 0]``, ``[..., 1]``, ``[..., 2]``) are treated as x, y, z.

    Returns:
        The Euclidean norm of the ``np.gradient`` velocity of x, y and z.
        For a ``(n_frames, >=3)`` input this is one value per frame.
    """
    positions = np.array(frames)
    vx, vy, vz = (np.gradient(positions[..., axis]) for axis in range(3))
    return np.sqrt(vx**2 + vy**2 + vz**2)

def calc_racket_spin_rate(racket_angle):
    """Return the absolute change of ``racket_angle`` between consecutive samples.

    The result has one element fewer than the input.
    """
    return np.abs(np.diff(racket_angle))

def calc_shot_consistency(wrist_speeds):
    """Return the population standard deviation (ddof=0) of ``wrist_speeds``.

    Smaller values indicate a steadier wrist speed over the shot.
    """
    spread = np.std(wrist_speeds, ddof=0)
    return spread

def calc_racket_spin_type(racket_angle, flat_threshold=0.1):
    """Label each consecutive change in ``racket_angle`` as a spin type.

    A change whose magnitude is below ``flat_threshold`` is 'flat'. Otherwise
    a positive change is 'topspin' and a negative (or NaN) change is 'slice'.

    Args:
        racket_angle: 1-D sequence of angle samples for one shot.
        flat_threshold: Magnitude below which a change counts as 'flat'.

    Returns:
        A list with one label per consecutive pair of samples, i.e.
        ``len(racket_angle) - 1`` entries. When there is nothing to classify
        (fewer than two samples), it returns ``['flat'] * len(racket_angle)``.
    """
    angle_change = np.diff(racket_angle)
    # Classify every change. Starting the loop at index 1 silently dropped the
    # first change of each shot.
    spin_type = [
        'flat' if abs(change) < flat_threshold
        else 'topspin' if change > 0
        else 'slice'
        for change in angle_change
    ]
    if not spin_type:
        spin_type = ['flat'] * len(racket_angle)
    return spin_type

def generate_feedback(shot_analysis, *, consistency_threshold=0.2, speed_threshold=2, spin_rate_threshold=0.5):
    """Build a flat list of coaching messages from per-shot metrics.

    Args:
        shot_analysis: Mapping with parallel per-shot lists under
            "classifications", "shot_consistency", "speeds", "spin_type_left",
            "spin_type_right", "spin_rate_left" and "spin_rate_right". The
            lists are expected to be the same length; iteration stops at the
            shortest one.
        consistency_threshold: A consistency value above this asks for a
            steadier wrist speed.
        speed_threshold: A speed below this suggests swinging faster.
        spin_rate_threshold: When both spin rates are below this, more spin
            is suggested.

    Returns:
        List of message strings. Each shot contributes its messages in order:
        consistency, speed, zero or more spin-type messages, then spin rate.
    """
    per_shot = zip(
        shot_analysis["classifications"],
        shot_analysis["shot_consistency"],
        shot_analysis["speeds"],
        shot_analysis["spin_type_left"],
        shot_analysis["spin_type_right"],
        shot_analysis["spin_rate_left"],
        shot_analysis["spin_rate_right"],
    )

    feedback = []
    for _classification, consistency, speed, spin_left, spin_right, rate_left, rate_right in per_shot:
        # Feedback based on consistency.
        if consistency > consistency_threshold:
            feedback.append("Try to make your shot more consistent. A more stable wrist speed will help.")
        else:
            feedback.append("Good job maintaining consistency in your shot!")

        # Feedback based on speed.
        if speed < speed_threshold:
            feedback.append("Consider increasing your swing speed for more power.")
        else:
            feedback.append("Great power! Keep up the good swing speed.")

        # Feedback based on spin type. These checks are independent, so a
        # single shot can get several messages.
        if 'topspin' in spin_left or 'topspin' in spin_right:
            feedback.append("Excellent job on generating topspin! Keep the racket face slightly closed for more.")
        if 'slice' in spin_left or 'slice' in spin_right:
            feedback.append("Focus on making a cleaner slice shot by angling your racket more.")
        if 'flat' in spin_left and 'flat' in spin_right:
            feedback.append("Try adding more spin by increasing wrist angle during the shot.")

        # Feedback on spin rate.
        if rate_left < spin_rate_threshold and rate_right < spin_rate_threshold:
            feedback.append("Try generating more spin by adjusting the angle of your racket at impact.")
        else:
            feedback.append("Good spin rate! Keep focusing on maintaining racket control.")

    return feedback

def analyse_shots(joint_frames):
    """Detect shots in a recording and compute per-shot metrics and feedback.

    Args:
        joint_frames: Mapping from joint name to a per-frame sequence of
            coordinates. It must contain "left_wrist" and "right_wrist", and
            every value must support slicing by frame index.

    Returns:
        A dict of parallel per-shot lists under "intervals",
        "classifications", "speeds", "hands", "spin_type_left",
        "spin_type_right", "shot_consistency", "spin_rate_left" and
        "spin_rate_right". It also holds a flat "feedback" list of messages
        produced by generate_feedback.
    """
    lw_speed = calc_joint_speed(joint_frames["left_wrist"])
    rw_speed = calc_joint_speed(joint_frames["right_wrist"])
    shot_intervals = detect_shot_times(lw_speed, rw_speed)

    shot_analysis = {
        key: []
        for key in (
            "intervals",
            "classifications",
            "speeds",
            "hands",
            "spin_type_left",
            "spin_type_right",
            "shot_consistency",
            "spin_rate_left",
            "spin_rate_right",
        )
    }

    for start_t, end_t in shot_intervals:
        window = slice(start_t, end_t + 1)  # end frame is inclusive
        shot_analysis["intervals"].append((start_t, end_t))

        # Deep-copy only this shot's frames. classify_shot still gets its own
        # independent copy, but the cost now scales with the shot length rather
        # than the whole recording on every shot.
        shot_joint_frames = {
            joint: copy.deepcopy(frames[window])
            for joint, frames in joint_frames.items()
        }

        classification = classify_shot(shot_joint_frames)
        shot_analysis["classifications"].append(classification)

        # NOTE(review): the "racket angle" here is the wrist's first coordinate
        # (frame[0]), not a measured racket angle. Confirm this proxy is intended.
        racket_angle_left = np.array([frame[0] for frame in shot_joint_frames["left_wrist"]])
        racket_angle_right = np.array([frame[0] for frame in shot_joint_frames["right_wrist"]])
        shot_analysis["spin_type_left"].append(calc_racket_spin_type(racket_angle_left))
        shot_analysis["spin_type_right"].append(calc_racket_spin_type(racket_angle_right))

        shot_analysis["spin_rate_left"].append(np.mean(calc_racket_spin_rate(racket_angle_left)))
        shot_analysis["spin_rate_right"].append(np.mean(calc_racket_spin_rate(racket_angle_right)))

        lw_window, rw_window = lw_speed[window], rw_speed[window]
        shot_analysis["shot_consistency"].append(
            np.mean([calc_shot_consistency(lw_window), calc_shot_consistency(rw_window)])
        )

        # The faster wrist (by mean speed) is taken as the hitting hand.
        # Ties go to the right hand.
        mean_lw_speed, mean_rw_speed = np.mean(lw_window), np.mean(rw_window)
        if mean_lw_speed > mean_rw_speed:
            hand, speed = "left", mean_lw_speed
        else:
            hand, speed = "right", mean_rw_speed
        shot_analysis["speeds"].append(speed)
        shot_analysis["hands"].append(hand)

    shot_analysis["feedback"] = generate_feedback(shot_analysis)
    return shot_analysis
