Cannot save semantic segmentation predictions as GeoJSON (rastervision+lightninig) #1741
-
Issue: semantic segmentation predictions are not output in proper format I have written a semantic segmentation script very similar to the rastervision+lightning example on the tutorials webpage. Training, evaluation, and prediction all work well with no errors. My only issue is figuring out how to save my predictions as a GeoJSON file so I can overlay them on my input image in GIS software. All of my code is copied below, but I believe the most relevant piece is my import albumentations as A
from rastervision.pytorch_learner import (
SemanticSegmentationRandomWindowGeoDataset,
SemanticSegmentationSlidingWindowGeoDataset,
SemanticSegmentationVisualizer)
from rastervision.core.data import ClassConfig
from tqdm.autonotebook import tqdm
import torch
from torch.utils.data import DataLoader
from torch.nn import functional as F
from torchvision.models.segmentation import deeplabv3_resnet50
import pytorch_lightning as pl
from rastervision.pipeline.file_system import make_dir
from pytorch_lightning.loggers import CSVLogger, TensorBoardLogger
from rastervision.core.data import SemanticSegmentationLabels
from rastervision.core.data.label_store.semantic_segmentation_label_store_config import PolygonVectorOutputConfig
import pathlib
class SemanticSegmentation(pl.LightningModule):
def __init__(self, deeplab, lr=1e-4):
super().__init__()
self.deeplab = deeplab
self.lr = lr
def forward(self, img):
return self.deeplab(img)['out']
def training_step(self, batch, batch_idx):
img, mask = batch
img = img.float()
mask = mask.long()
out = self.forward(img)
loss = F.cross_entropy(out, mask)
log_dict = {'train_loss': loss}
self.log_dict(log_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss
def validation_step(self, batch, batch_idx):
img, mask = batch
img = img.float()
mask = mask.long()
out = self.forward(img)
loss = F.cross_entropy(out, mask)
log_dict = {'validation_loss': loss}
self.log_dict(log_dict, on_step=True, on_epoch=True, prog_bar=True, logger=True)
return loss
def configure_optimizers(self):
optimizer = torch.optim.Adam(
self.parameters(), lr=self.lr)
return optimizer
class RVLightningTraining:
def __init__(self, tr_uris, val_uris, pred_uris, cc_names, cc_colors):
self.train_uris = tr_uris
self.val_uris = val_uris
self.pred_uris = pred_uris
self.cc = ClassConfig(names=cc_names, colors=cc_colors, null_class="null")
def build_train_ds(self):
data_augmentation_transform = A.Compose([
A.Flip(),
A.ShiftScaleRotate(),
A.RGBShift()
])
train_ds = SemanticSegmentationRandomWindowGeoDataset.from_uris(
class_config=self.cc,
image_uri=self.train_uris[0],
aoi_uri=self.train_uris[2],
label_vector_uri=self.train_uris[1],
label_vector_default_class_id=self.cc.get_class_id('DF'),
size_lims=(300, 350),
out_size=325,
max_windows=10,
transform=data_augmentation_transform)
return train_ds
def build_val_ds(self):
val_ds = SemanticSegmentationSlidingWindowGeoDataset.from_uris(
class_config=self.cc,
image_uri=self.val_uris[0],
aoi_uri=self.val_uris[2],
label_vector_uri=self.val_uris[1],
label_vector_default_class_id=self.cc.get_class_id('DF'),
size=325,
stride=325)
return val_ds
def build_pred_ds(self):
def get_tiles_from_dir(d):
v = [x.as_posix() for x in pathlib.Path(d).glob("*.tif")]
return v[:10]
pred_ds = SemanticSegmentationSlidingWindowGeoDataset.from_uris(
class_config=self.cc,
#image_uri=self.pred_uris[0],
image_uri=get_tiles_from_dir(self.pred_uris[0]),
aoi_uri=self.pred_uris[1],
size=325,
stride=325)
return pred_ds
def train(self):
batch_size = 8
lr = 1e-4
epochs = 3
output_dir = './semseg-trees-lightning/'
make_dir(output_dir)
fast_dev_run = False
deeplab = deeplabv3_resnet50(num_classes=len(self.cc) + 1)
model = SemanticSegmentation(deeplab, lr=lr)
tb_logger = TensorBoardLogger(save_dir=output_dir + "tensorboard", flush_secs=10)
trainer = pl.Trainer(
accelerator='auto',
min_epochs=1,
max_epochs=epochs+1,
default_root_dir=output_dir + "trainer/",
logger=[tb_logger],
fast_dev_run=fast_dev_run,
log_every_n_steps=1,
)
tds = self.build_train_ds()
vds = self.build_val_ds()
train_dl = DataLoader(tds, batch_size=batch_size, shuffle=True, num_workers=4)
val_dl = DataLoader(vds, batch_size=batch_size, num_workers=4)
trainer.fit(model, train_dl, val_dl)
trainer.save_checkpoint(output_dir + "trainer/final-model.ckpt")
def prediction_iterator(self, pred_dl, model, n=None):
for i, (x, _) in tqdm(enumerate(pred_dl)):
if n is not None and i >= n:
break
with torch.inference_mode():
out_batch = model(x)
for out in out_batch:
yield out.numpy()
def predict(self, n=5):
def fix_keys(data):
for _ in range(len(data)):
k, v = data.popitem(False)
newk = k[k.index(".")+1:]
data[newk] = v
return data
ckpt_path = './semseg-trees-lightning/trainer/final-model.ckpt'
ckpt = torch.load(ckpt_path)
state_dict = ckpt["state_dict"]
state_dict = fix_keys(state_dict)
deeplab = deeplabv3_resnet50(num_classes=len(self.cc) + 1)
deeplab.load_state_dict(state_dict)
model = SemanticSegmentation(deeplab)
pds = self.build_pred_ds()
pred_dl = DataLoader(pds, batch_size=8, num_workers=4)
predictions = self.prediction_iterator(pred_dl, model, n)
pred_labels = SemanticSegmentationLabels.from_predictions(
pds.windows,
predictions,
smooth=True,
extent=pds.scene.extent,
num_classes=len(self.cc)+1
)
# PROBABLY MOST RELEVANT PART HERE
pred_labels.save(
uri="./pred-labels-scores",
crs_transformer=pds.scene.raster_source.crs_transformer,
class_config=self.cc,
vector_outputs=[
PolygonVectorOutputConfig(class_id=i) for i in range(len(self.cc))
],
)
return pred_labels
def test():
train_image_uri = "train/230OG_Orthomosaic_export_WedMar08181810258552.tif"
train_label_uri = "train/Tree_polygons_00_230OG.geojson"
train_aoi_uri = "train/border_polygon_00_230OG.geojson"
val_image_uri = "val/230OG_Orthomosaic_export_WedMar08181810258552.tif"
val_label_uri = "val/Tree_polygons_00_230OG.geojson"
val_aoi_uri = "val/border_polygon_00_230OG.geojson"
pred_image_dir = "pred/230OG_Orthomosaic"
pred_aoi_uri = "pred/border_polygon_00_230OG.geojson"
obj = RVLightningTraining(
[train_image_uri, train_label_uri, train_aoi_uri],
[val_image_uri, val_label_uri, val_aoi_uri],
[pred_image_dir, pred_aoi_uri],
["DF", "null"],
["orange", "black"]
)
obj.train()
return obj.predict()
if __name__ == "__main__":
test() Is it possible to output GeoJSON polygons with this configuration? Any help would be much appreciated. Thanks |
Beta Was this translation helpful? Give feedback.
Replies: 2 comments 12 replies
-
Also, why is my score.tif file 40GB while my labels.tif is 3.5GB? |
Beta Was this translation helpful? Give feedback.
-
It looks like you need to also specify As for controlling the TIFF outputs, see the docs for SemanticSegmentationSmoothLabels.save. Here's a snippet that should fix the issue: from os.path import join
pred_uri = "./pred-labels-scores"
pred_labels.save(
uri=pred_uri,
crs_transformer=pds.scene.raster_source.crs_transformer,
class_config=self.cc,
# set to False to skip writing `labels.tif`
discrete_output=True,
# set to False to skip writing `scores.tif`
smooth_output=True,
# set to True to quantize floating point score values to uint8 in scores.tif to reduce file size
smooth_as_uint8=False,
vector_outputs=[
PolygonVectorOutputConfig(class_id=i, uri=join(pred_uri, f'class-{i}.json'))
for i in range(len(self.cc))
],
) |
Beta Was this translation helpful? Give feedback.
It looks like you need to also specify
uri
in eachPolygonVectorOutputConfig
for it to work; otherwise, RV will just logSkipping VectorOutputConfig at index 0 due to missing uri.
etc. and do nothing. I don't think this is ideal. Thanks for reporting this.As for controlling the TIFF outputs, see the docs for SemanticSegmentationSmoothLabels.save.
Here's a snippet that should fix the issue: