Skip to content

Commit

Permalink
Add weights per feature dimension
Browse files Browse the repository at this point in the history
  • Loading branch information
GuilhermeGSousa committed Dec 22, 2024
1 parent 8dd6993 commit 24582f2
Show file tree
Hide file tree
Showing 4 changed files with 87 additions and 16 deletions.
8 changes: 5 additions & 3 deletions src/features/mm_feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -36,11 +36,13 @@ class MMFeature : public Resource {

virtual void display_data(const Ref<EditorNode3DGizmo>& p_gizmo, const Transform3D p_transform, const float* p_data) const {};

void normalize(float* p_data) const;
void denormalize(float* p_data) const;
float calculate_normalized_weight() const {
virtual float calculate_normalized_weight(int64_t p_feature_dim) const {
return weight / get_dimension_count();
}

void normalize(float* p_data) const;
void denormalize(float* p_data) const;

GETSET(float, weight, 1.0f);
GETSET(NormalizationMode, normalization_mode, Standard);
GETSET(PackedFloat32Array, means);
Expand Down
52 changes: 52 additions & 0 deletions src/features/mm_trajectory_feature.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -162,6 +162,24 @@ void MMTrajectoryFeature::display_data(const Ref<EditorNode3DGizmo>& p_gizmo, co
delete[] dernomalized_data;
}

float MMTrajectoryFeature::calculate_normalized_weight(int64_t p_feature_dim) const {

float weight = MMFeature::calculate_normalized_weight(p_feature_dim);

const uint32_t point_dim = _get_point_dimension_count();

const bool is_height = include_height && (p_feature_dim % point_dim) == 2;
const bool is_facing = include_facing && (p_feature_dim % point_dim) == (include_height ? 3 : 2);

if (is_height) {
weight *= height_weight;
} else if (is_facing) {
weight *= facing_weight;
}

return weight;
}

TypedArray<Dictionary> MMTrajectoryFeature::get_trajectory_points(const Transform3D& p_character_transform, const PackedFloat32Array& p_trajectory_data) const {
ERR_FAIL_COND_V(p_trajectory_data.is_empty(), TypedArray<Dictionary>());

Expand Down Expand Up @@ -189,14 +207,48 @@ TypedArray<Dictionary> MMTrajectoryFeature::get_trajectory_points(const Transfor
return result;
}

bool MMTrajectoryFeature::get_include_height() const {
return include_height;
}

void MMTrajectoryFeature::set_include_height(bool value) {
include_height = value;
notify_property_list_changed();
}

bool MMTrajectoryFeature::get_include_facing() const {
return include_facing;
}

void MMTrajectoryFeature::set_include_facing(bool value) {
include_facing = value;
notify_property_list_changed();
}

void MMTrajectoryFeature::_validate_property(PropertyInfo& p_property) const {
if (p_property.name == StringName("facing_weight")) {
if (!include_facing) {
p_property.usage = PROPERTY_USAGE_NO_EDITOR;
}
}

if (p_property.name == StringName("height_weight")) {
if (!include_height) {
p_property.usage = PROPERTY_USAGE_NO_EDITOR;
}
}
}

void MMTrajectoryFeature::_bind_methods() {
ClassDB::bind_method(D_METHOD("get_trajectory_points", "character_transform", "trajectory_data"), &MMTrajectoryFeature::get_trajectory_points);
BINDER_PROPERTY_PARAMS(MMTrajectoryFeature, Variant::FLOAT, past_delta_time);
BINDER_PROPERTY_PARAMS(MMTrajectoryFeature, Variant::INT, past_frames);
BINDER_PROPERTY_PARAMS(MMTrajectoryFeature, Variant::FLOAT, future_delta_time);
BINDER_PROPERTY_PARAMS(MMTrajectoryFeature, Variant::INT, future_frames);
BINDER_PROPERTY_PARAMS(MMTrajectoryFeature, Variant::BOOL, include_height);
BINDER_PROPERTY_PARAMS(MMTrajectoryFeature, Variant::FLOAT, height_weight);
BINDER_PROPERTY_PARAMS(MMTrajectoryFeature, Variant::BOOL, include_facing);
BINDER_PROPERTY_PARAMS(MMTrajectoryFeature, Variant::FLOAT, facing_weight);
}

uint32_t MMTrajectoryFeature::_get_point_dimension_count() const {
Expand Down
17 changes: 15 additions & 2 deletions src/features/mm_trajectory_feature.h
Original file line number Diff line number Diff line change
Expand Up @@ -26,16 +26,29 @@ class MMTrajectoryFeature : public MMFeature {

virtual void display_data(const Ref<EditorNode3DGizmo>& p_gizmo, const Transform3D p_transform, const float* p_data) const override;

virtual float calculate_normalized_weight(int64_t p_feature_dim) const override;

TypedArray<Dictionary> get_trajectory_points(const Transform3D& p_character_transform, const PackedFloat32Array& p_trajectory_data) const;

GETSET(double, past_delta_time, 0.1);
GETSET(int64_t, past_frames, 1);
GETSET(double, future_delta_time, 0.1);
GETSET(int64_t, future_frames, 5);
GETSET(bool, include_height, false);
GETSET(bool, include_facing, true);

bool include_height{false};
bool get_include_height() const;
void set_include_height(bool value);

GETSET(float, height_weight, 1.0);

bool include_facing{true};
bool get_include_facing() const;
void set_include_facing(bool value);

GETSET(float, facing_weight, 1.0);

protected:
void _validate_property(PropertyInfo& p_property) const;
static void _bind_methods();

private:
Expand Down
26 changes: 15 additions & 11 deletions src/mm_animation_library.cpp
Original file line number Diff line number Diff line change
Expand Up @@ -218,11 +218,13 @@ float MMAnimationLibrary::_compute_feature_costs(int p_pose_index, const PackedF
continue;
}

const float feature_cost =
distance_squared((motion_data.ptr() + start_frame_index + start_feature_index),
(p_query.ptr() + start_feature_index),
feature->get_dimension_count()) *
feature->calculate_normalized_weight();
float feature_cost = 0.f;
for (int64_t feature_dim_index = 0; feature_dim_index < feature->get_dimension_count(); feature_dim_index++) {
feature_cost += distance_squared((motion_data.ptr() + start_frame_index + start_feature_index + feature_dim_index),
(p_query.ptr() + start_feature_index + feature_dim_index),
1) *
feature->calculate_normalized_weight(feature_dim_index);
}

if (p_feature_costs) {
p_feature_costs->get_or_add(feature->get_class(), feature_cost);
Expand Down Expand Up @@ -250,11 +252,13 @@ MMQueryOutput MMAnimationLibrary::_search_naive(const PackedFloat32Array& p_quer
continue;
}

const float feature_cost =
distance_squared((motion_data.ptr() + start_feature_index),
(p_query.ptr() + start_feature_index - start_frame_index),
feature->get_dimension_count()) *
feature->calculate_normalized_weight();
float feature_cost = 0.f;
for (int64_t feature_dim_index = 0; feature_dim_index < feature->get_dimension_count(); feature_dim_index++) {
feature_cost += distance_squared((motion_data.ptr() + start_frame_index + start_feature_index + feature_dim_index),
(p_query.ptr() + start_feature_index + feature_dim_index),
1) *
feature->calculate_normalized_weight(feature_dim_index);
}

feature_costs.get_or_add(feature->get_class(), feature_cost);
pose_cost += feature_cost;
Expand Down Expand Up @@ -302,7 +306,7 @@ MMQueryOutput MMAnimationLibrary::_search_kd_tree(const PackedFloat32Array& p_qu
continue;
}
for (int64_t feature_dim_index = 0; feature_dim_index < feature->get_dimension_count(); feature_dim_index++) {
dimension_weights.push_back(feature->calculate_normalized_weight());
dimension_weights.push_back(feature->calculate_normalized_weight(feature_dim_index));
}
}

Expand Down

0 comments on commit 24582f2

Please sign in to comment.