[BUG] Kernel dies while training segger model #90

Open
marzaidi opened this issue Feb 22, 2025 · 0 comments

marzaidi commented Feb 22, 2025

Hi! I am following the tutorial at https://elihei2.github.io/segger_dev/notebooks/segger_tutorial/; this is my code:

from pathlib import Path

# imports follow the segger tutorial; exact module paths may differ between segger versions
from segger.data.parquet.sample import STSampleParquet
from segger.training.segger_data_module import SeggerDataModule
from segger.training.train import LitSegger
from pytorch_lightning import Trainer
from pytorch_lightning.loggers import CSVLogger

xenium_data_dir = Path('/output-XETG00050__0018897__H-7412-21-1A__20240614__101941')
segger_data_dir = Path('/segger_output')

# parse the raw Xenium sample
sample = STSampleParquet(
    base_dir=xenium_data_dir,
    n_workers=4,
    sample_type='xenium',  # use 'xenium_v2' if the cell boundaries come from the segmentation kit
    # weights=gene_celltype_abundance_embedding,  # uncomment if gene-celltype embeddings are available
)

# tile the sample and write the training/validation/test datasets
sample.save(
    data_dir=segger_data_dir,
    k_bd=3,
    dist_bd=15.0,
    k_tx=3,
    dist_tx=5.0,
    tile_width=120,
    tile_height=120,
    neg_sampling_ratio=5.0,
    frac=1.0,
    val_prob=0.1,
    test_prob=0.2,
)

models_dir = Path('/models')

# data module that loads the tiles written above
dm = SeggerDataModule(
    data_dir=segger_data_dir,
    batch_size=2,
    num_workers=2,
)

dm.setup()

is_token_based = True
num_tx_tokens = 500

# number of boundary ("bd") node features, taken from the first training tile
num_bd_features = dm.train[0].x_dict["bd"].shape[1]

ls = LitSegger(
    is_token_based=is_token_based,
    num_node_features={"tx": num_tx_tokens, "bd": num_bd_features},
    init_emb=8,
    hidden_channels=64,
    out_channels=16,
    heads=4,
    num_mid_layers=1,
    aggr='sum',
)

trainer = Trainer(
    accelerator='cpu',
    strategy='auto',
    precision='16-mixed',
    devices=1,  # set a higher number if more GPUs are available
    max_epochs=100,
    default_root_dir=models_dir,
    logger=CSVLogger(models_dir),
)

trainer.fit(
    model=ls,
    datamodule=dm,
)

I am not sure what is causing the issue here, but while training the model (which takes hours) the kernel dies, or the session simply stops because the run takes so long and the connection is lost. Is there anything I can change in my code to speed up training?

Thank you in advance.
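
For reference, the Trainer above runs on a single CPU device for 100 epochs, which by itself can account for multi-hour training times, and '16-mixed' precision mainly pays off on a GPU. Below is a minimal sketch of a faster configuration, assuming a CUDA-capable GPU is available; the batch size, worker count, and epoch count are illustrative values, not numbers from the tutorial.

# Sketch only: reuses the objects defined above (ls, models_dir, segger_data_dir),
# reconfigured for GPU training.
trainer = Trainer(
    accelerator='gpu',             # train on the GPU instead of the CPU
    devices=1,
    precision='16-mixed',          # mixed precision is effective on CUDA devices
    max_epochs=50,                 # illustrative; fewer epochs finish sooner
    default_root_dir=models_dir,
    logger=CSVLogger(models_dir),
)

dm = SeggerDataModule(
    data_dir=segger_data_dir,
    batch_size=4,                  # illustrative; larger batches keep the GPU busy
    num_workers=4,                 # illustrative; more dataloader workers if CPU cores allow
)
dm.setup()

trainer.fit(model=ls, datamodule=dm)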
