-
Notifications
You must be signed in to change notification settings - Fork 1
/
gan.py
212 lines (176 loc) · 8.66 KB
/
gan.py
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
149
150
151
152
153
154
155
156
157
158
159
160
161
162
163
164
165
166
167
168
169
170
171
172
173
174
175
176
177
178
179
180
181
182
183
184
185
186
187
188
189
190
191
192
193
194
195
196
197
198
199
200
201
202
203
204
205
206
207
208
209
210
211
import torch
import torch.nn as nn
import torch.optim as optim
from torch.utils.data import DataLoader, Dataset
import torchtext.utils as tutils
import json
# Define the Generator model
class Generator(nn.Module):
    """Map a batch of noise vectors to synthetic samples squashed into [-1, 1].

    Layer stack: Linear(input_dim -> 128) -> LeakyReLU(0.2) -> Linear(128 -> 256)
    -> LeakyReLU(0.2) -> Linear(256 -> output_dim) -> Tanh.
    """

    def __init__(self, input_dim, output_dim):
        super().__init__()
        negative_slope = 0.2
        layers = [
            nn.Linear(input_dim, 128),
            nn.LeakyReLU(negative_slope),
            nn.Linear(128, 256),
            nn.LeakyReLU(negative_slope),
            nn.Linear(256, output_dim),
            nn.Tanh(),
        ]
        # Kept as a single Sequential named ``model`` so the state_dict keys
        # ("model.0.weight", ...) written by torch.save stay the same.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Run ``x`` through the layer stack and return the generated batch."""
        generated = self.model(x)
        return generated
# Define the Discriminator model
class Discriminator(nn.Module):
    """Score each input sample with a probability in (0, 1).

    Layer stack: Linear(input_dim -> 256) -> LeakyReLU(0.2) -> Linear(256 -> 128)
    -> LeakyReLU(0.2) -> Linear(128 -> 1) -> Sigmoid.
    """

    def __init__(self, input_dim):
        super().__init__()
        negative_slope = 0.2
        layers = [
            nn.Linear(input_dim, 256),
            nn.LeakyReLU(negative_slope),
            nn.Linear(256, 128),
            nn.LeakyReLU(negative_slope),
            nn.Linear(128, 1),
            nn.Sigmoid(),
        ]
        # Kept as a single Sequential named ``model`` so the state_dict keys
        # ("model.0.weight", ...) written by torch.save stay the same.
        self.model = nn.Sequential(*layers)

    def forward(self, x):
        """Return the per-sample probability that ``x`` is real."""
        scores = self.model(x)
        return scores
# Custom Dataset class for loading FHIR data in NDJSON format
class FHIRDataset(Dataset):
    """Dataset over an NDJSON file holding one FHIR resource (JSON object) per line.

    Items are encoded on access with ``fhir_resource_to_tensor``, using the
    module-level ``fhir_profiles_resources_json`` and ``fhir_value_set`` bundles.
    """

    def __init__(self, file_path):
        """Parse every non-blank line of ``file_path`` (UTF-8) as JSON.

        Raises:
            json.JSONDecodeError: if a non-blank line is not valid JSON.
        """
        with open(file_path, encoding='utf8', mode='r') as f:
            # Blank or whitespace-only lines are skipped rather than handed to
            # json.loads, which would raise on them.
            self.data = [json.loads(line) for line in f if line.strip()]

    def __len__(self):
        """Return the number of resources loaded."""
        return len(self.data)

    def __getitem__(self, index):
        """Encode resource ``index`` against the profile for its ``resourceType``."""
        resource = self.data[index]
        resource_type = resource.get('resourceType')
        fhir_profile_resource = resource_from_profile(resource_type, fhir_profiles_resources_json)
        return fhir_resource_to_tensor(resource, resource_type, fhir_profile_resource, fhir_value_set)
def resource_from_profile(fhir_resource, fhir_profiles_resources):
    """Return the bundle entry whose resource ``id`` equals ``fhir_resource``.

    Args:
        fhir_resource: the id to look up (the dataset passes the resource's
            ``resourceType`` value here).
        fhir_profiles_resources: bundle dict with an ``entry`` list whose items
            carry a ``resource`` dict.

    Returns:
        The first matching entry (the whole entry, not just ``entry['resource']``),
        or None when nothing matches or the bundle has no entries.
    """
    for entry in fhir_profiles_resources.get('entry') or []:
        if entry.get('resource', {}).get('id') == fhir_resource:
            return entry
    return None
#Convert FHIR to Tensor
def fhir_resource_to_tensor(fhir_resource_json, fhir_resource, fhir_profile_resource, fhir_value_set):
    """Encode one FHIR resource as a ``(1, n_elements)`` float tensor on ``device``.

    There is one column per element of the profile's ``differential``. The first
    element describes the resource itself and is skipped. Each column holds:

    * ``-1`` when the resource has no value for the element, or the value
      cannot be encoded (code not found, or an unbound non-numeric value);
    * the list length for list values;
    * ``len(date_to_one_hot(value))`` for ``date`` elements (see NOTE below);
    * the number of keys for ``CodeableConcept`` values;
    * the value's index in the code system bound via ``binding.valueSet``;
    * ``float(value)`` for unbound boolean/numeric values.

    Args:
        fhir_resource_json: the resource as parsed JSON (dict).
        fhir_resource: the resource type name; not used by the encoding.
        fhir_profile_resource: bundle entry holding the StructureDefinition
            (read as ``['resource']['differential']['element']``).
        fhir_value_set: bundle passed through to get_concept_index_from_codesystem.

    Returns:
        Tensor of shape ``(1, n_elements)``, moved to the module-level ``device``.
    """
    # The first differential element is the resource itself, so skip it.
    elements = fhir_profile_resource['resource'].get('differential')['element'][1:]
    tensor = torch.empty((1, len(elements)))
    for i, element in enumerate(elements):
        # Only the second path segment is looked up, so nested paths such as
        # "Patient.contact.name" read the parent field ("contact").
        # NOTE(review): if the profile uses choice ids like "deceased[x]", they
        # will not match concrete JSON keys (e.g. "deceasedBoolean") -- confirm.
        fhir_element = element['id'].split('.')[1]
        value = fhir_resource_json.get(fhir_element, None)
        tensor[0, i] = _encode_element_value(element, value, fhir_value_set)
    return tensor.to(device)


def _encode_element_value(element, value, fhir_value_set):
    """Return the scalar encoding of one element's value (see fhir_resource_to_tensor)."""
    if value is None:
        return -1
    if isinstance(value, list):
        return len(value)
    # Elements without a declared ``type`` yield no type code instead of crashing.
    type_code = (element.get('type') or [{}])[0].get('code')
    if type_code == 'date':
        # NOTE(review): this is the fixed length of the one-hot vector (244 with
        # the default 1900-2100 year range), so every present date encodes to the
        # same number and the date itself is lost. Kept unchanged until the
        # intended date encoding is decided.
        return len(date_to_one_hot(value))
    if type_code == 'CodeableConcept':
        # Number of keys in the CodeableConcept object.
        return len(value)
    value_set_url = (element.get('binding') or {}).get('valueSet')
    if value_set_url is not None:
        # The URL may carry a "|version" suffix; match on the bare URL.
        index = get_concept_index_from_codesystem(fhir_value_set, value_set_url.split('|')[0], value)
        # Unknown codes use the same -1 sentinel as missing values.
        return -1 if index is None else index
    # Unbound scalars (e.g. booleans) have no code system to index into.
    if isinstance(value, (bool, int, float)):
        return float(value)
    return -1
def date_to_one_hot(date, first_year=1900, last_year=2100):
    """One-hot encode a date string as concatenated year/month/day vectors.

    The result is a flat list of 0/1 ints laid out as follows:

    * ``last_year - first_year + 1`` year slots;
    * then 12 month slots;
    * then 31 day slots.

    That is 244 slots with the defaults. Each part has at most one 1. A part
    that is absent or out of range is all zeros; partial forms "YYYY" and
    "YYYY-MM" are accepted.

    Args:
        date: "YYYY", "YYYY-MM" or "YYYY-MM-DD". Anything from a "T" onward
            (the time/zone part of a dateTime) is ignored.
        first_year: first year covered by the year vector (inclusive).
        last_year: last year covered by the year vector (inclusive).

    Returns:
        list[int] of length ``last_year - first_year + 1 + 12 + 31``.
    """
    # Drop any time part so a zone offset such as "-05:00" cannot add extra
    # "-"-separated pieces to the split below.
    parts = date.split('T', 1)[0].split('-')
    year = parts[0]
    month = parts[1] if len(parts) > 1 else None
    day = parts[2] if len(parts) > 2 else None
    years = [str(y) for y in range(first_year, last_year + 1)]
    months = [str(m).zfill(2) for m in range(1, 13)]
    days_in_month = [str(d).zfill(2) for d in range(1, 32)]
    year_vector = [int(year == y) for y in years]
    month_vector = [int(month == m) for m in months]
    day_vector = [int(day == d) for d in days_in_month]
    return year_vector + month_vector + day_vector
def get_concept_index_from_codesystem(fhir_value_set, fhir_value_set_url, concept_code):
    """Return the position of ``concept_code`` in the code system tied to a value set.

    Scans ``fhir_value_set['entry']`` for resources whose ``valueSet`` equals
    ``fhir_value_set_url``. It returns the index of the first top-level
    ``concept`` whose ``code`` equals ``concept_code``. Matching resources
    without a ``concept`` list are skipped. Nested concepts
    (``concept[].concept``) are not searched.

    Returns:
        int index within that resource's ``concept`` list, or None when no
        matching resource contains the code.
    """
    for entry in fhir_value_set['entry']:
        resource = entry['resource']
        if resource.get('valueSet') != fhir_value_set_url:
            continue
        # A missing 'concept' list means "no codes", not a TypeError on None.
        for index, concept in enumerate(resource.get('concept') or []):
            if concept.get('code') == concept_code:
                return index
    return None
def train_gan(generator, discriminator, dataloader, num_epochs, device,
              lr=0.0002, noise_dim=None, clip_value=0.01):
    """Train ``generator`` and ``discriminator`` adversarially with BCE loss.

    Each batch runs two steps:

    * Discriminator step: the discriminator is updated to score real samples
      as 1 and generated samples as 0.
    * Generator step: the generator is updated so that the discriminator
      scores its samples as 1.

    Args:
        generator: module mapping noise of shape ``(batch, noise_dim)`` to samples.
        discriminator: module mapping samples to probabilities (BCELoss input).
        dataloader: yields batches of real samples as tensors.
        num_epochs: number of passes over ``dataloader``.
        device: device that real batches and noise are placed on.
        lr: Adam learning rate for both networks (0.0002 is the script's value).
        noise_dim: width of the noise input. Defaults to the module-level
            ``input_dim``.
        clip_value: if not None, discriminator weights are clamped to
            ``[-clip_value, clip_value]`` after each discriminator update
            (WGAN-style weight clipping). Pass None to disable it.

    Raises:
        ValueError: if ``dataloader.dataset`` is empty.
    """
    # Check if the input data is empty
    if len(dataloader.dataset) == 0:
        raise ValueError("The input dataloader is empty. Please make sure it contains data.")
    if noise_dim is None:
        noise_dim = input_dim  # module-level setting used by the script
    # Define loss function and optimizers
    criterion = nn.BCELoss()
    generator_optimizer = optim.Adam(generator.parameters(), lr=lr)
    discriminator_optimizer = optim.Adam(discriminator.parameters(), lr=lr)
    # Training loop
    for epoch in range(num_epochs):
        fake_data = None
        for batch_idx, real_data in enumerate(dataloader):
            real_data = real_data.to(device)
            batch_size = real_data.shape[0]

            # Train discriminator with real data. Labels mirror the output shape,
            # whatever the rank of the samples.
            discriminator.zero_grad()
            real_output = discriminator(real_data)
            real_loss = criterion(real_output, torch.ones_like(real_output))
            real_loss.backward()

            # Train discriminator with generated data. The detached copy keeps
            # this step from sending gradients into the generator.
            noise = torch.randn(batch_size, noise_dim, device=device)
            fake_data = generator(noise)
            fake_output = discriminator(fake_data.detach())
            fake_loss = criterion(fake_output, torch.zeros_like(fake_output))
            fake_loss.backward()
            discriminator_loss = real_loss + fake_loss
            discriminator_optimizer.step()

            # Clip the discriminator's weights (not its gradients).
            if clip_value is not None:
                with torch.no_grad():
                    for p in discriminator.parameters():
                        p.clamp_(-clip_value, clip_value)

            # Train generator. It uses the NON-detached fake batch, so the loss
            # gradient flows back into the generator's parameters.
            generator.zero_grad()
            fake_output = discriminator(fake_data)
            generator_loss = criterion(fake_output, torch.ones_like(fake_output))
            generator_loss.backward()
            generator_optimizer.step()

            if batch_idx % 100 == 0:  # Only print stats every 100th batch
                print(
                    f"Epoch [{epoch + 1}/{num_epochs}], "
                    f"Batch complete with [{real_data.shape[0]} passes], "  # Print the actual batch size
                    f"Discriminator Loss: {discriminator_loss.item():.4f}, "
                    f"Generator Loss: {generator_loss.item():.4f}")
        # Print one generated sample after each epoch (skipped if no batch ran).
        if fake_data is not None:
            generated_text = fake_data[0].detach().cpu().numpy()  # Convert tensor to numpy array
            print(f"Generated Text: {generated_text}")
# Width of the random noise vector fed to the generator.
input_dim = 1
# Width of the generated output; it must equal the width of the encoded
# samples the discriminator receives.
output_dim = 27

# Use the first CUDA device when one is available, otherwise the CPU.
if torch.cuda.is_available():
    device = torch.device("cuda:0")
else:
    device = torch.device("cpu")


def _load_json_file(path):
    """Parse and return the UTF-8 JSON document stored at ``path``."""
    with open(path, encoding='utf8', mode='r') as handle:
        return json.load(handle)


# Bundle searched by get_concept_index_from_codesystem when encoding coded values.
fhir_value_set = _load_json_file('fhir/valuesets.json')
# Bundle searched by resource_from_profile for each resource type's profile.
fhir_profiles_resources_json = _load_json_file('fhir/profiles-resources.json')
# Entry point of the script
if __name__ == "__main__":
    # Training hyper-parameters. ``lr`` stays a module-level name because
    # train_gan may read it as a global.
    lr = 0.0002  # Learning rate
    batch_size = 1000  # Batch size for training
    num_epochs = 200
    # Load the FHIR dataset
    dataset = FHIRDataset('data/Patient.ndjson')
    if len(dataset) > 0:
        # Size the networks from a real encoded sample so their width always
        # matches the number of columns fhir_resource_to_tensor emits.
        output_dim = dataset[0].shape[-1]
    dataloader = DataLoader(dataset, batch_size=batch_size, shuffle=True)
    # Initialize generator and discriminator
    generator = Generator(input_dim, output_dim).to(device)
    discriminator = Discriminator(output_dim).to(device)
    # Train the GAN. train_gan creates its own loss and optimizers, so none are
    # built here.
    train_gan(generator, discriminator, dataloader, num_epochs, device)
    # Save trained models
    torch.save(generator.state_dict(), 'generator.pth')
    torch.save(discriminator.state_dict(), 'discriminator.pth')