-
Notifications
You must be signed in to change notification settings - Fork 2
Commit
This commit does not belong to any branch on this repository, and may belong to a fork outside of the repository.
Merge pull request #13 from mantidproject/37527_deploy_cnn
37527 scripts to run CNN Bragg peaks finding for WISH
- Loading branch information
Showing
8 changed files
with
257 additions
and
64 deletions.
There are no files selected for viewing
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,67 @@ | ||
from mantid.simpleapi import CreatePeaksWorkspace, AddPeak | ||
import numpy as np | ||
|
||
def make_3d_array(ws, ntubes=1520, npix_per_tube=128, ntubes_per_bank=152, nmonitors=5): | ||
""" | ||
Extract Ydata from WISH MatrixWorkspace and shape into 3d array | ||
:param ws: MatrixWorkspace of WISH data with xunit wavelength | ||
:param ntubes: number of tubes in instrument (WISH) | ||
:param npix_per_tube: number of detector pixels in each tube | ||
:param ntubes_per_bank: number of tubes per bank | ||
:param nmonitors: number of monitor spectra (assumed first spectra in ws) | ||
:return: 3d numpy array ntubes x npix per tube x nbins | ||
""" | ||
y = ws.extractY()[nmonitors:,:] # exclude monitors - alternatively load with monitors separate? | ||
nbins = ws.blocksize() # 4451 | ||
y = np.reshape(y, (ntubes, npix_per_tube, nbins))[:,::-1,:] # ntubes x npix x nbins (note flipped pix along tube) | ||
# reverse order of tubes in each bank | ||
nbanks = ntubes//ntubes_per_bank | ||
for ibank in range(0, nbanks): | ||
istart = ibank*ntubes_per_bank | ||
iend = (ibank+1)*ntubes_per_bank | ||
y[istart:iend,:,:] = y[istart:iend,:,:][::-1,:,:] | ||
return y | ||
|
||
|
||
def createPeaksWorkspaceFromIndices(ws, peak_wsname, indices, data): | ||
""" | ||
Create a PeaksWorkspace using indices of peaks found in 3d array. WISH has 4 detectors so put peak in central pixel. | ||
Could add peak using avg QLab for peaks at that lambda in all detectors but peak placed in nearest detector anyway | ||
for more details see https://github.com/mantidproject/mantid/issues/31944 | ||
:param ws: MatrixWorkspace of WISH data with xunit wavelength | ||
:param peak_wsname: Output name of peaks workspace created | ||
:param indices: indices of peaks found in 3d array | ||
:param data: 3d array of data | ||
:return: PeaksWorkspace | ||
""" | ||
ispec = findSpectrumIndex(indices, *data.shape[0:2]) | ||
peaks = CreatePeaksWorkspace(InstrumentWorkspace=ws, NumberOfPeaks=0, | ||
OutputWorkspace=peak_wsname, EnableLogging=False) | ||
for ipk in range(len(ispec)): | ||
wavelength = ws.readX(ispec[ipk])[indices[ipk,2]] | ||
# four detectors per spectrum so use one of the central ones | ||
detIDs = ws.getSpectrum(ispec[ipk]).getDetectorIDs() | ||
idet = (len(detIDs)-1)//2 # pick central pixel | ||
AddPeak(PeaksWorkspace=peaks, RunWorkspace=ws, TOF=wavelength, DetectorID=detIDs[idet], | ||
BinCount=data[indices[ipk,0], indices[ipk,1], indices[ipk,2]], EnableLogging=False) | ||
return peaks | ||
|
||
|
||
def findSpectrumIndex(indices, ntubes=1520, npix_per_tube=128, ntubes_per_bank=152, nmonitors=5): | ||
""" | ||
:param indices: indices of found peaks in 3d array | ||
:param ntubes: number of tubes in instrument (WISH) | ||
:param npix_per_tube: number of detector pixels in each tube | ||
:param ntubes_per_bank: number of tubes per bank | ||
:param nmonitors: number of monitor spectra (assumed first spectra in ws) | ||
:return: list of spectrum indindices of workspace corresponding to indices of found peaks in 3d array | ||
""" | ||
# find bank and then reverse order of tubes in bank | ||
ibank = np.floor(indices[:,0]/ntubes_per_bank) | ||
itube = ibank*ntubes_per_bank + ((ntubes_per_bank-1) - indices[:,0] % ntubes_per_bank) | ||
# flip tube | ||
ipix = (npix_per_tube-1) - indices[:,1] | ||
# get spectrum index | ||
specIndex = np.ravel_multi_index((itube.astype(int), ipix), | ||
dims=(ntubes, npix_per_tube), order='C') + nmonitors | ||
return specIndex.tolist() # so elements are of type int not numpy.int32 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,111 @@ | ||
import torchvision | ||
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor | ||
import torch as tc | ||
import warnings | ||
from cnn.WISHDataSets import WISHWorkspaceDataSet | ||
import numpy as np | ||
from bragg_utils import createPeaksWorkspaceFromIndices | ||
from tqdm import tqdm | ||
from Diffraction.single_crystal.base_sx import BaseSX | ||
import time | ||
|
||
class BraggDetectCNN: | ||
""" | ||
Detects Bragg's peaks from a workspace using a pre trained deep learning model created using Faster R-CNN model with a ResNet-50-FPN backbone. | ||
Example usage is as shown below. | ||
#1) Set the path of the .pt file that contains the weights of the pre trained FasterRCNN model | ||
cnn_weights_path = r"" | ||
# 2) Create a peaks workspace containing bragg peaks detected with a confidence greater than conf_threshold | ||
cnn_bragg_peaks_detector = BraggDetectCNN(model_weights_path=cnn_weights_path, batch_size=64, workers=0, iou_threshold=0.001) | ||
cnn_bragg_peaks_detector.find_bragg_peaks(workspace="WISH00042730", conf_threshold=0.0, q_tol=0.05) | ||
""" | ||
|
||
def __init__(self, model_weights_path, batch_size=64, workers=0, iou_threshold=0.001): | ||
""" | ||
:param model_weights_path: Path to the .pt file containing the weights of the pre trained CNN model | ||
:param batch_size: Batch size to be used when loading tube data for inferencing. | ||
:param workers: Number of loader worker processes to do multi-process data loading. workers=0 means data loading in main process | ||
:param iou_threshold: IOU(Intersection Over Union) threshold to filter out overlapping bounding boxes for detected peaks | ||
""" | ||
self.model_weights_path = model_weights_path | ||
self.device = self._select_device() | ||
self.model = self._load_cnn_model_from_weights(self.model_weights_path) | ||
self.batch_size = batch_size | ||
self.workers = workers | ||
self.iou_threshold = iou_threshold | ||
|
||
|
||
def find_bragg_peaks(self, workspace, output_ws_name="CNN_Peaks", conf_threshold=0.0, q_tol=0.05): | ||
""" | ||
Find bragg peaks using the pre trained FasterRCNN model and create a peaks workspace | ||
:param workspace: Workspace name or the object of Workspace from WISH, ex: "WISH0042730" | ||
:param output_ws_name: Name of the peaks workspace | ||
:param conf_threshold: Confidence threshold to filter peaks inferred from RCNN | ||
:param q_tol: qlab tolerance to remove duplicate peaks | ||
""" | ||
start_time = time.time() | ||
data_set, predicted_indices = self._do_cnn_inferencing(workspace) | ||
filtered_indices = predicted_indices[predicted_indices[:, -1] > conf_threshold] | ||
filtered_indices_rounded = np.round(filtered_indices[:, :-1]).astype(int) | ||
peaksws = createPeaksWorkspaceFromIndices(data_set.get_workspace(), output_ws_name, filtered_indices_rounded, data_set.get_ws_as_3d_array()) | ||
for ipk, pk in enumerate(peaksws): | ||
pk.setIntensity(filtered_indices[ipk, -1]) | ||
|
||
#Filter duplicates by qlab | ||
BaseSX.remove_duplicate_peaks_by_qlab(peaksws, q_tol) | ||
data_set.delete_rebunched_ws() | ||
print(f"Bragg peaks finding from FasterRCNN model is completed in {time.time()-start_time} seconds!") | ||
|
||
|
||
def _do_cnn_inferencing(self, workspace): | ||
data_set = WISHWorkspaceDataSet(workspace) | ||
data_loader = tc.utils.data.DataLoader(data_set, batch_size=self.batch_size, shuffle=False, num_workers=self.workers) | ||
self.model.eval() | ||
predicted_indices_with_score = [] | ||
for batch_idx, img_batch in enumerate(tqdm(data_loader, desc="Processing batches of tubes")): | ||
for img_idx, img in enumerate(img_batch): | ||
tube_idx = batch_idx * data_loader.batch_size + img_idx | ||
with tc.no_grad(): | ||
prediction = self.model([img.to(self.device)])[0] | ||
nms_prediction = self._apply_nms(prediction, self.iou_threshold) | ||
for box, score in zip(nms_prediction['boxes'], nms_prediction['scores']): | ||
tof = (box[0]+box[2])/2 | ||
tube_res = (box[1]+box[3])/2 | ||
predicted_indices_with_score.append([tube_idx, tube_res.cpu(), tof.cpu(), score.cpu()]) | ||
return data_set, np.array(predicted_indices_with_score) | ||
|
||
|
||
def _apply_nms(self, orig_prediction, iou_thresh): | ||
keep = torchvision.ops.nms(orig_prediction['boxes'], orig_prediction['scores'], iou_thresh) | ||
final_prediction = orig_prediction | ||
final_prediction['boxes'] = final_prediction['boxes'][keep] | ||
final_prediction['scores'] = final_prediction['scores'][keep] | ||
final_prediction['labels'] = final_prediction['labels'][keep] | ||
return final_prediction | ||
|
||
|
||
def _select_device(self): | ||
if tc.cuda.is_available(): | ||
print("GPU device is found!") | ||
return tc.device("cuda") | ||
else: | ||
warnings.warn( | ||
"Warning! GPU is not available, the program will run very slow..", RuntimeWarning) | ||
return tc.device("cpu") | ||
|
||
|
||
def _load_cnn_model_from_weights(self, weights_path): | ||
model = self._get_fasterrcnn_resnet50_fpn(num_classes=2) | ||
model.load_state_dict(tc.load(weights_path, map_location=self.device)) | ||
return model.to(self.device) | ||
|
||
|
||
def _get_fasterrcnn_resnet50_fpn(self, num_classes=2): | ||
model = torchvision.models.detection.fasterrcnn_resnet50_fpn(weights=None) | ||
in_features = model.roi_heads.box_predictor.cls_score.in_features | ||
# replace the head | ||
model.roi_heads.box_predictor = FastRCNNPredictor(in_features, num_classes) | ||
return model | ||
|
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,23 @@ | ||
Bragg Peaks detection using a Faster RCNN model | ||
================ | ||
|
||
Inorder to use the pretrained Faster RCNN model inside mantid, below steps are required. | ||
|
||
* Install mantid from conda `mamba create -n mantid_cnn -c mantid mantidworkbench` | ||
* Activate the conda environment with `mamba activate mantid_cnn` | ||
* Launch workbench from `workbench` command | ||
* Download the script repository's `scriptrepository\diffraction\WISH` directory as instructed here https://docs.mantidproject.org/nightly/workbench/scriptrepository.html | ||
* Check whether `<local path>\diffraction\WISH` path is available at `Python Script Directories` tab from `File->Manage User Directories`. | ||
* Close the workbench | ||
* From command line, change the directory to the place where the scripts were downloaded ex: `<local path>\diffraction\WISH\bragg-detect\cnn` | ||
* Within the same conda enviroment, install pytorch dependancies by running `pip install -r requirements.txt` | ||
* Install NVIDIA CUDA Deep Neural Network library (cuDNN) by running `conda install -c anaconda cudnn` | ||
* Re-launch workbench from `workbench` command | ||
* Below is an example code snippet to test the code. It will create a peaks workspace with the inferred peaks from the cnn and will do a peak filtering using the q_tol provided using `BaseSX.remove_duplicate_peaks_by_qlab`. | ||
```python | ||
from cnn.BraggDetectCNN import BraggDetectCNN | ||
model_weights = r'path/to/pretrained/fasterrcnn_resnet50_model_weights.pt' | ||
cnn_peaks_detector = BraggDetectCNN(model_weights_path=model_weights, batch_size=64) | ||
cnn_peaks_detector.find_bragg_peaks(workspace='WISH00042730', output_ws_name="CNN_Peaks", conf_threshold=0.0, q_tol=0.05) | ||
``` | ||
* If the above import is not working, check whether the `<local path>\diffraction\WISH` path is listed under `Python Script Directories` tab from `File->Manage User Directories`. |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,49 @@ | ||
import torch as tc | ||
import numpy as np | ||
import albumentations as A | ||
from albumentations.pytorch.transforms import ToTensorV2 | ||
from bragg_utils import make_3d_array | ||
from mantid.simpleapi import Load, Rebunch, Workspace, DeleteWorkspace | ||
|
||
|
||
class WISHWorkspaceDataSet(tc.utils.data.Dataset): | ||
def __init__(self, workspace): | ||
if isinstance(workspace, Workspace): | ||
if workspace.getAxis(0).getUnit().unitID() != "TOF": | ||
raise RuntimeError("Unit of the X-axis is expected to be TOF") | ||
ws = workspace | ||
ws_name = ws.getName() | ||
elif isinstance(workspace, str): | ||
ws = Load(Filename=workspace, OutputWorkspace=workspace, EnableLogging=False) | ||
ws_name = ws | ||
else: | ||
raise RuntimeError("Invalid workspace type - must be Workspace object or a name of a workspace to Load") | ||
|
||
self.rebunched_ws = Rebunch(InputWorkspace=ws, NBunch=3, OutputWorkspace=f"__{ws_name}_cnn_rebunched", EnableLogging=False) | ||
self.ws_3d = make_3d_array(self.rebunched_ws) | ||
print(f"Data set for {workspace} is created with shape{self.ws_3d.shape}") | ||
self.trans = A.Compose([A.pytorch.transforms.ToTensorV2(p=1.0)]) | ||
|
||
def get_workspace(self): | ||
if self.rebunched_ws is None: | ||
raise RuntimeError("Rebunched workspace is not available!") | ||
return self.rebunched_ws | ||
|
||
def delete_rebunched_ws(self): | ||
DeleteWorkspace(Workspace=self.rebunched_ws) | ||
self.rebunched_ws = None | ||
|
||
def get_ws_as_3d_array(self): | ||
return self.ws_3d | ||
|
||
def __len__(self): | ||
""" | ||
Return number of tubes in the ws | ||
""" | ||
return self.ws_3d.shape[0] | ||
|
||
def __getitem__(self, idx): | ||
tube_data = self.ws_3d[idx, ...] | ||
tube_data = np.tile(tube_data[:,:, None], 3).astype(np.float32) | ||
frame = self.trans(image=tube_data) | ||
return frame['image'] |
Empty file.
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters
Original file line number | Diff line number | Diff line change |
---|---|---|
@@ -0,0 +1,6 @@ | ||
-f https://download.pytorch.org/whl/cu118 | ||
torch | ||
torchvision | ||
|
||
albumentations==1.4.0 | ||
tqdm==4.66.2 |
This file contains bidirectional Unicode text that may be interpreted or compiled differently than what appears below. To review, open the file in an editor that reveals hidden Unicode characters.
Learn more about bidirectional Unicode characters