

import numpy as np
import tensorflow as tf
import os

# Trained CNN shot-recognition model, loaded once when this module is imported.
# The path is relative, so it is resolved against the current working directory.
_model_file = os.path.join("analysis_pipeline/", "cnn_recognition_model1.h5")
model = tf.keras.models.load_model(_model_file)

# Shot labels; classify_shot indexes this list with the winning output index.
class_names = ["backhand", "forehand", "service", "smash"]

def classify_shot(joint_frames):
    """Classifies 3D pose animation into 'backhand', 'forehand', 'service', or 'smash' shot

    The per-joint tracks are stacked into an image-like array, rescaled to the
    0-255 range, resized to 100 x 15 and passed to the module-level ``model``.

    Parameters:
        joint_frames (dict) -- dictionary of 3D joint coordinate frames with joint names as the keys;
            every name listed in ``ordered_joints`` below must be present.
            NOTE(review): assumes each value is a (num_frames, 3) sequence and
            that all joints have the same number of frames -- confirm with callers.

    Returns:
        classification (str), -- shot classification (an entry of ``class_names``)
        confidence (float) -- confidence score (softmax value of the chosen class)

    Raises:
        KeyError -- if a required joint name is missing from ``joint_frames``
        ValueError -- if the joint arrays have mismatching shapes or contain no data
    """
    # The joint order fixes which column of the skeleton image each joint occupies.
    ordered_joints = (
        "head",
        "left_elbow",
        "left_foot",
        "left_wrist",
        "left_hip",
        "left_knee",
        "left_shoulder",
        "neck",
        "right_elbow",
        "right_foot",
        "right_wrist",
        "right_hip",
        "right_knee",
        "right_shoulder",
        "torso",
    )

    # Stack to (joints, ...) and then move the joint axis to position 1.
    pose_img = np.stack([np.array(joint_frames[joint]) for joint in ordered_joints])
    pose_img = pose_img.swapaxes(0, 1)

    if pose_img.size == 0:
        # Without this check np.max fails below with an opaque "zero-size array" error.
        raise ValueError("joint_frames contains no coordinate data")

    # Scale into [-1, 1] by the largest absolute coordinate. An all-zero pose
    # has a peak of 0; dividing by it would fill the input with NaNs, so it is
    # left unscaled (every value then maps to the mid-grey level 127.5).
    peak = np.abs(pose_img).max()
    if peak == 0:
        peak = 1
    skeleton_img = pose_img / peak

    # Shift and scale [-1, 1] to [0, 255].
    skeleton_img = (skeleton_img + 1) * 127.5

    # tf.image.resize is given a batch dimension, which is stripped again afterwards.
    # NOTE(review): presumably 100 x 15 matches the model's input size -- confirm.
    skeleton_img = tf.image.resize(skeleton_img[None, ...], [100, 15])[0, ...].numpy()

    predictions = model.predict(skeleton_img[None, ...])
    # NOTE(review): softmax here assumes the model outputs raw logits -- confirm
    # that the saved model does not already end in a softmax layer.
    score = tf.nn.softmax(predictions[0])

    return class_names[np.argmax(score)], np.max(score)
