-
-
Notifications
You must be signed in to change notification settings - Fork 0
/
utils.py
58 lines (46 loc) · 1.68 KB
/
utils.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
import json
from pathlib import Path
from enum import Enum
class ParticipantStateCols(Enum):
    """Column names of the per-participant state records.

    Each member's value is the literal dictionary key used in the JSON records
    written by ``create_participant_json_file`` and modified by ``update_json``.
    """

    # Identifies a participant; update_json matches records on this column.
    EMAIL = "Email"
    # Flag column; create_participant_json_file initializes it to False.
    FL_CLIENT_INSTALLED = "Fl Client Installed"
    # Flag column; create_participant_json_file initializes it to False.
    PROJECT_APPROVED = "Project Approved"
    # Flag column; create_participant_json_file initializes it to False.
    ADDED_PRIVATE_DATA = "Added Private Data"
    # Stored as a "<current>/<total>" string, initialized to "0/<total_rounds>".
    ROUND = "Round (current/total)"
    # Free-form progress text, initialized to "N/A".
    MODEL_TRAINING_PROGRESS = "Training Progress"
def read_json(data_path: Path):
    """Load and return the JSON document stored at ``data_path``.

    The file is decoded explicitly as UTF-8 (the encoding JSON text uses)
    instead of the platform's locale encoding. Non-ASCII content, such as
    participant emails, therefore reads identically on every OS.

    Args:
        data_path: Path (or path-like / str) of the JSON file to read.

    Returns:
        The deserialized JSON value. For the participant files produced by
        this module that is a list of dicts.

    Raises:
        OSError: if the file cannot be opened (e.g. FileNotFoundError).
        UnicodeDecodeError: if the file is not valid UTF-8.
        json.JSONDecodeError: if the file content is not valid JSON.
    """
    with open(data_path, encoding="utf-8") as fp:
        return json.load(fp)
def save_json(data: "dict | list", data_path: Path):
    """Serialize ``data`` as 4-space-indented JSON and write it to ``data_path``.

    Any existing file is overwritten. The file is opened with an explicit
    UTF-8 encoding so that writes and ``read_json`` agree regardless of the
    platform locale.

    Args:
        data: JSON-serializable value. The participant files in this module
            store a list of dicts.
        data_path: Destination path (or path-like / str).

    Raises:
        OSError: if the file cannot be opened for writing.
        TypeError: if ``data`` contains values ``json`` cannot serialize.
    """
    with open(data_path, "w", encoding="utf-8") as fp:
        json.dump(data, fp, indent=4)
def create_participant_json_file(
    participants: list, total_rounds: int, output_path: Path
):
    """Write a fresh participant-state JSON file with one record per participant.

    Every record starts in the initial state:
    - the three flag columns are ``False``;
    - the round column is ``"0/<total_rounds>"``;
    - the training-progress column is ``"N/A"``.

    The file is written with ``save_json``, so any existing file at
    ``output_path`` is overwritten.

    Args:
        participants: Values stored under the ``Email`` column, one record
            each (presumably email strings; confirm with callers).
        total_rounds: Total number of rounds, used as the denominator of the
            ``Round (current/total)`` column.
        output_path: Destination of the JSON file.
    """
    # One dict per participant, keyed by the enum's column-name values.
    data = [
        {
            ParticipantStateCols.EMAIL.value: participant,
            ParticipantStateCols.FL_CLIENT_INSTALLED.value: False,
            ParticipantStateCols.PROJECT_APPROVED.value: False,
            ParticipantStateCols.ADDED_PRIVATE_DATA.value: False,
            ParticipantStateCols.ROUND.value: f"0/{total_rounds}",
            ParticipantStateCols.MODEL_TRAINING_PROGRESS.value: "N/A",
        }
        for participant in participants
    ]
    save_json(data=data, data_path=output_path)
def update_json(
    data_path: Path,
    participant_email: str,
    column_name: "ParticipantStateCols | str",
    column_val: "str | bool",
):
    """Set one column for every record whose email matches, then persist it.

    If ``column_name`` is not a known column, or no record has
    ``participant_email`` in its ``Email`` column, the function returns
    without rewriting the file.

    Args:
        data_path: JSON file in the format produced by
            ``create_participant_json_file``.
        participant_email: Value compared against each record's ``Email``
            column; every matching record is updated.
        column_name: Column to set, either a ``ParticipantStateCols`` member
            or its string value (e.g. ``"Round (current/total)"``).
        column_val: New value for the column (the file stores both booleans
            and strings).
    """
    # Look the column up by value instead of using `x in ParticipantStateCols`.
    # That containment test behaves differently across Python versions for
    # non-member arguments: before 3.12 it raises TypeError, and from 3.12 a
    # matching plain string passes but has no `.value`. Calling the enum
    # maps a member to itself and a value string to its member.
    try:
        column = ParticipantStateCols(column_name)
    except ValueError:
        return  # unknown column: deliberate no-op

    participant_history = read_json(data_path=data_path)
    updated = False
    for participant in participant_history:
        if participant[ParticipantStateCols.EMAIL.value] == participant_email:
            participant[column.value] = column_val
            updated = True

    # Avoid rewriting the file when nothing changed.
    if updated:
        save_json(participant_history, data_path)