Skip to content

Commit

Permalink
correction de l'algo kmeans
Browse files Browse the repository at this point in the history
  • Loading branch information
Jovillios committed Oct 22, 2023
1 parent 4763666 commit 2b90601
Showing 1 changed file with 19 additions and 14 deletions.
33 changes: 19 additions & 14 deletions transductive.py
Original file line number Diff line number Diff line change
Expand Up @@ -55,7 +55,6 @@ def w(self, features, centroids, centroid_i, beta):
distances = {}
for centroid in centroids:
distances[centroid] = torch.norm(features - centroids[centroid])

sum_distances = torch.zeros_like(distances[centroid_i])
for centroid in centroids:
sum_distances += torch.exp(-beta * distances[centroid] ** 2)
Expand All @@ -71,24 +70,25 @@ def forward(self, support_features, query_features, n_iter, beta):

# initialize centroids
centroids = {}
for class_label, class_features in support_features.items():
centroids[class_label] = torch.mean(class_features["features"], dim=0)
for class_label, _, class_features in support_output:
centroids[class_label] = torch.mean(class_features, dim=0)

# update centroids
for i in range(n_iter):
for _ in range(n_iter):
new_centroids = {}
for centroid in centroids:
suma = 0
sumb = 0
for class_label, _ , class_feature in support_output:
if class_label == centroid:
suma += 1
suma += class_feature
sumb += 1
for class_label, _ , class_feature in query_output:
if class_label == centroid:
weight = self.w(class_feature, centroids, centroid, beta)
suma += weight * class_feature
sumb += weight
centroids[centroid] = suma / sumb
weight = self.w(class_feature, centroids, centroid, beta)
suma += weight * class_feature
sumb += weight
new_centroids[centroid] = suma / sumb
centroids = new_centroids

# assign query features to centroids with least distance
for class_label, _ , class_feature in query_output:
Expand All @@ -108,15 +108,20 @@ def forward(self, support_features, query_features, n_iter, beta):


def test():
# change seed
torch.manual_seed(10)
kmeans = KMeans()
support_features = {}
query_features = {}
for i in range(5):
support_features[i] = {"features": torch.randn(5, 10), "indices": [0, 1, 2, 3, 4]}
query_features[i] = {"features": torch.randn(5, 10), "indices": [0, 1, 2, 3, 4]}
print(kmeans(support_features, query_features, 10, 2))
support_features[f"ok {i}"] = {"features": 5 * torch.randn(1, 368), "indices": [0]}
query_features[f"ok {i}"] = {"features": 5 * torch.randn(5, 368) + i , "indices": [0, 1, 2, 3, 4]}
print(kmeans(support_features, query_features, 1, 5))


if __name__ == "__main__":
test()





0 comments on commit 2b90601

Please sign in to comment.