Skip to content

Commit

Permalink
Merge pull request #1570 from danforthcenter/kmeans-grayscale
Browse files Browse the repository at this point in the history
Kmeans grayscale
  • Loading branch information
nfahlgren authored Nov 4, 2024
2 parents 92aea4c + 5c207f3 commit 2a2d811
Show file tree
Hide file tree
Showing 27 changed files with 70 additions and 37 deletions.
11 changes: 5 additions & 6 deletions docs/kmeans_classifier.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## Classification using a kmeans cluster model

The first function (`pcv.predict_kmeans`) takes a target image and uses a trained kmeans model produced by [`pcv.learn.train_kmeans`](train_kmeans.md) to classify regions of the target image by the trained clusters. The second function (`pcv.mask_kmeans`) takes a list of clusters and produces the combined mask from clusters of interest.
The first function (`pcv.predict_kmeans`) takes a target image and uses a trained kmeans model produced by [`pcv.learn.train_kmeans`](train_kmeans.md) to classify regions of the target image by the trained clusters. The second function (`pcv.mask_kmeans`) takes a list of clusters and produces the combined mask from clusters of interest. The target and training images may be in grayscale or RGB image format.

**plantcv.kmeans_classifier.predict_kmeans**(img, model_path="./kmeansout.fit", patch_size=10)

Expand All @@ -18,14 +18,13 @@ The first function (`pcv.predict_kmeans`) takes a target image and uses a traine
- **Example use below**


**plantcv.kmeans_classifier.mask_kmeans**(labeled_img, k, patch_size, cat_list=None)
**plantcv.kmeans_classifier.mask_kmeans**(labeled_img, k, cat_list=None)

**outputs** Either a combined mask of the requestedlist of clusters or a dictionary of each cluster as a separate mask with keys corresponding to the cluster number

- **Parameters:**
- labeled_img = The output from predict_kmeans, an image with pixels labeled according to their cluster assignment
- k = The number of clusters in the trained model
- patch_size = Size of the NxN neighborhood around each pixel, used for classification
- cat_list = List of clusters to include in a combined mask. If None, output is a dictionary of separate masks for each cluster

- **Context:**
Expand All @@ -46,9 +45,9 @@ labeled_img = pcv.predict_kmeans(img='./leaf_example.png',
model_path="./kmeansout_leaf.fit", patch_size=5)

#Choosing clusters for each category within the seed image
background = pcv.mask_kmeans(labeled_img=labeled_img, k=10, patch_size=5, cat_list=[0, 2, 4, 6, 7])
sick = pcv.mask_kmeans(labeled_img=labeled_img, k=10, patch_size=5, cat_list=[1, 3])
leaf = pcv.mask_kmeans(labeled_img=labeled_img, k=10, patch_size=5, cat_list=[5, 8, 9])
background = pcv.mask_kmeans(labeled_img=labeled_img, k=10 cat_list=[0, 2, 4, 6, 7])
sick = pcv.mask_kmeans(labeled_img=labeled_img, k=10, cat_list=[1, 3])
leaf = pcv.mask_kmeans(labeled_img=labeled_img, k=10, cat_list=[5, 8, 9])

```

Expand Down
2 changes: 1 addition & 1 deletion docs/train_kmeans.md
Original file line number Diff line number Diff line change
@@ -1,6 +1,6 @@
## Kmeans clustering training

This function takes in a collection of training images and fits a patch-based kmeans cluster model for later use in classifying cluster assignment in a target image.
This function takes in a collection of training images and fits a patch-based kmeans cluster model for later use in classifying cluster assignment in a target image. The target and training images may be in grayscale or RGB image format.

**plantcv.learn.train_kmeans**(img_dir, k, out_path="./kmeansout.fit", prefix="", patch_size=10, sigma=5, sampling=None, seed=1, num_imgs=0, n_init=10)

Expand Down
14 changes: 8 additions & 6 deletions plantcv/learn/train_kmeans.py
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,15 @@ def train_kmeans(img_dir, k, out_path="./kmeansout.fit", prefix="", patch_size=1
else:
training_files = random.choices(file_names, k=num_imgs) # choosing a set of random files
# Read and extract patches
i = 0
for img_name in training_files:
for idx, img_name in enumerate(training_files):
if prefix in img_name:
img = cv2.imread(os.path.join(img_dir, img_name))
if i == 0:
img = cv2.imread(os.path.join(img_dir, img_name), -1)
if idx == 0:
# Getting info from first image
patches = patch_extract(img, patch_size=patch_size, sigma=sigma, sampling=sampling)
else:
# Concatenating each additional image
patches = np.vstack((patches, patch_extract(img, patch_size=patch_size, sigma=sigma, sampling=sampling)))
i += 1

kmeans = MiniBatchKMeans(n_clusters=k, n_init=n_init, random_state=seed)
fitted = kmeans.fit(patches)
Expand All @@ -79,7 +77,11 @@ def patch_extract(img, patch_size=10, sigma=5, sampling=None, seed=1):
:return patches_lin: numpy.ndarray
"""
# Gaussian blur
img_blur = np.round(gaussian(img, sigma=sigma, channel_axis=2)*255).astype(np.uint16)
if len(img.shape) == 2:
img_blur = np.round(gaussian(img, sigma=sigma)*255).astype(np.uint16)
elif len(img.shape) == 3 and img.shape[2] == 3:
img_blur = np.round(gaussian(img, sigma=sigma, channel_axis=2)*255).astype(np.uint16)

# Extract patches
patches = image.extract_patches_2d(img_blur, (patch_size, patch_size),
max_patches=sampling, random_state=seed)
Expand Down
46 changes: 27 additions & 19 deletions plantcv/plantcv/kmeans_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -24,9 +24,23 @@ def predict_kmeans(img, model_path="./kmeansout.fit", patch_size=10):
kmeans = load(model_path)
train_img, _, _ = pcv.readimage(img)

before = after = int((patch_size - 1)/2) # odd
if patch_size % 2 == 0: # even
before = int((patch_size-2)/2)
after = int(patch_size/2)

# Padding
if len(train_img.shape) == 2: # gray
train_img = np.pad(train_img, pad_width=((before, after), (before, after)), mode="edge")
elif len(train_img.shape) == 3 and train_img.shape[2] == 3: # rgb
train_img = np.pad(train_img, pad_width=((before, after), (before, after), (0, 0)), mode="edge")

# Shapes
mg = np.floor(patch_size / 2).astype(np.int32)
h, w, _ = train_img.shape
if len(train_img.shape) == 2:
h, w = train_img.shape
elif len(train_img.shape) == 3 and train_img.shape[2] == 3:
h, w, _ = train_img.shape

# Do the prediction
train_patches = patch_extract(train_img, patch_size=patch_size)
Expand All @@ -38,7 +52,7 @@ def predict_kmeans(img, model_path="./kmeansout.fit", patch_size=10):
return labeled


def mask_kmeans(labeled_img, k, patch_size=10, cat_list=None):
def mask_kmeans(labeled_img, k, cat_list=None):
"""
Uses the predicted clusters from a target image to generate a binary mask.
Inputs:
Expand All @@ -52,29 +66,23 @@ def mask_kmeans(labeled_img, k, patch_size=10, cat_list=None):
:param patch_size: positive non-zero integer
:param cat_list: list of positive non-zero integers
"""
mg = np.floor(patch_size / 2).astype(np.int32)
h, w = labeled_img.shape
if cat_list is None:
mask_dict = {}
L = [*range(k)]
for i in L:
mask = np.ones(labeled_img.shape)
mask = np.logical_and(mask, labeled_img != i)
mask[:, 0:mg] = False
mask[:, w-mg:w] = False
mask[0:mg, :] = False
mask[h-mg:h, :] = False
mask_light = abs(1-mask)
mask_light = np.where(labeled_img == i, 255, 0)
_debug(visual=mask_light, filename=os.path.join(params.debug_outdir, "_kmeans_mask_"+str(i)+".png"))
mask_dict[str(i)] = mask_light
return mask_dict
mask = np.ones(labeled_img.shape)
for label in cat_list:
mask = np.logical_and(mask, labeled_img != label)
mask[:, 0:mg] = False
mask[:, w-mg:w] = False
mask[0:mg, :] = False
mask[h-mg:h, :] = False
mask_light = abs(1-mask)
# Store debug
debug = params.debug
# Change to None so that logical_or does not plot each stepwise addition
params.debug = None
for idx, i in enumerate(cat_list):
if idx == 0:
mask_light = np.where(labeled_img == i, 255, 0)
else:
mask_light = pcv.logical_or(mask_light, np.where(labeled_img == i, 255, 0))
params.debug = debug
_debug(visual=mask_light, filename=os.path.join(params.debug_outdir, "_kmeans_combined_mask.png"))
return mask_light
2 changes: 2 additions & 0 deletions tests/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -82,6 +82,8 @@ def __init__(self):
self.cluster_names_too_many = os.path.join(self.datadir, "cluster_names_too_many.txt")
# Kmeans classifier directory
self.kmeans_classifier_dir = os.path.join(self.datadir, "kmeans_classifier_dir")
# Kmeans classifier grayscale directory
self.kmeans_classifier_gray_dir = os.path.join(self.datadir, "kmeans_classifier_gray_dir")

@staticmethod
def load_hsi(pkl_file):
Expand Down
2 changes: 2 additions & 0 deletions tests/learn/conftest.py
Original file line number Diff line number Diff line change
Expand Up @@ -17,6 +17,8 @@ def __init__(self):
self.rgb_values_table = os.path.join(self.datadir, "rgb_values_table.txt")
# Kmeans training directory
self.kmeans_train_dir = os.path.join(self.datadir, "kmeans_train_dir")
# Kmeans training grayscale directory
self.kmeans_train_gray_dir = os.path.join(self.datadir, "kmeans_train_gray_dir")


@pytest.fixture(scope="session")
Expand Down
13 changes: 11 additions & 2 deletions tests/learn/test_kmeans_train.py
Original file line number Diff line number Diff line change
Expand Up @@ -7,12 +7,21 @@ def test_train_kmeans(learn_test_data, tmpdir):
# Create a test tmp directory
cache_dir = tmpdir.mkdir("cache")
training_dir = learn_test_data.kmeans_train_dir
training_dir_gray = learn_test_data.kmeans_train_gray_dir
outfile_subset = os.path.join(str(cache_dir), "kmeansout_subset.fit")
outfile_full = os.path.join(str(cache_dir), "kmeansout_full.fit")
outfile_subset_gray = os.path.join(str(cache_dir), "kmeansout_subset_gray.fit")
outfile_full_gray = os.path.join(str(cache_dir), "kmeansout_full_gray.fit")
# Train full model and partial model
train_kmeans(img_dir=training_dir, prefix="kmeans_train",
out_path=outfile_subset, k=5, num_imgs=3)
out_path=outfile_subset, k=5, patch_size=4, num_imgs=3)
train_kmeans(img_dir=training_dir, prefix="kmeans_train",
out_path=outfile_full, k=5)
out_path=outfile_full, k=5, patch_size=4)
train_kmeans(img_dir=training_dir_gray, prefix="kmeans_train",
out_path=outfile_subset_gray, k=5, patch_size=4, num_imgs=3)
train_kmeans(img_dir=training_dir_gray, prefix="kmeans_train",
out_path=outfile_full_gray, k=5, patch_size=4)
assert os.path.exists(outfile_subset)
assert os.path.exists(outfile_full)
assert os.path.exists(outfile_subset_gray)
assert os.path.exists(outfile_full_gray)
17 changes: 14 additions & 3 deletions tests/plantcv/test_kmeans_classifier.py
Original file line number Diff line number Diff line change
Expand Up @@ -6,14 +6,25 @@
def test_kmeans_classifier(test_data):
"""Test for PlantCV."""
input_dir = test_data.kmeans_classifier_dir
labeled_img = predict_kmeans(img=input_dir+"/test_image.jpg", model_path=input_dir+"/kmeans_out.fit", patch_size=5)
input_dir_gray = test_data.kmeans_classifier_gray_dir
labeled_img = predict_kmeans(img=input_dir+"/test_image.jpg", model_path=input_dir+"/kmeans_out.fit", patch_size=4)
labeled_img_gray = predict_kmeans(img=input_dir_gray+"/test_image_gray.jpg",
model_path=input_dir_gray+"/kmeans_out_gray.fit", patch_size=4)
test_labeled, _, _ = readimage(input_dir+"/labeled_image.png")
test_labeled_gray, _, _ = readimage(input_dir_gray+"/labeled_image_gray.png")
assert (labeled_img == test_labeled).all()
assert (labeled_img_gray == test_labeled_gray).all()

mask_dict = mask_kmeans(labeled_img=labeled_img, k=4, patch_size=5)
mask_dict = mask_kmeans(labeled_img=labeled_img, k=4)
mask_dict_gray = mask_kmeans(labeled_img=labeled_img_gray, k=4)
for i in range(4):
assert (readimage(input_dir+"/label_example_"+str(i)+".png")[0] == mask_dict[str(i)]).all()
for i in range(4):
assert (readimage(input_dir_gray+"/label_example_gray_"+str(i)+".png")[0] == mask_dict_gray[str(i)]).all()

combo_mask = mask_kmeans(labeled_img=labeled_img, k=4, patch_size=5, cat_list=[1, 2])
combo_mask = mask_kmeans(labeled_img=labeled_img, k=4, cat_list=[1, 2])
combo_mask_gray = mask_kmeans(labeled_img=labeled_img_gray, k=4, cat_list=[1, 2])
combo_example, _, _ = readimage(input_dir+"/combo_mask_example.png")
combo_example_gray, _, _ = readimage(input_dir_gray+"/combo_mask_example_gray.png")
assert (combo_mask == combo_example).all()
assert (combo_mask_gray == combo_example_gray).all()
Binary file modified tests/testdata/kmeans_classifier_dir/combo_mask_example.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/testdata/kmeans_classifier_dir/kmeans_out.fit
Binary file not shown.
Binary file modified tests/testdata/kmeans_classifier_dir/label_example_0.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/testdata/kmeans_classifier_dir/label_example_1.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/testdata/kmeans_classifier_dir/label_example_2.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/testdata/kmeans_classifier_dir/label_example_3.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file modified tests/testdata/kmeans_classifier_dir/labeled_image.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Binary file not shown.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.

0 comments on commit 2a2d811

Please sign in to comment.