Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

Bugfix/issue 4 predict incorrect channels #16

Merged
merged 9 commits into from
Jul 23, 2024
2 changes: 1 addition & 1 deletion src/components/app-bars/AlertBar/AlertBar.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -154,7 +154,7 @@ export const AlertBar = ({
color: colorTheme.contrastText,
}}
>
<Typography sx={{ pl: 3, pb: 1 }}>
<Typography sx={{ pl: 3, pb: 1, whiteSpace: "pre-wrap" }}>
{alertState.description}
</Typography>
</Box>
Expand Down
2 changes: 2 additions & 0 deletions src/components/cards/ExampleImageCard/ExampleImageCard.tsx
Original file line number Diff line number Diff line change
Expand Up @@ -6,6 +6,7 @@ import { dataConverter_v1v2 } from "utils/file-io/converters/dataConverter_v1v2"
import { dataSlice } from "store/data/dataSlice";
import { SerializedFileType } from "utils/file-io/types";
import { loadExampleImage } from "utils/file-io/loadExampleImage";
import { projectSlice } from "store/project";

type ExampleImageType = {
name: string;
Expand Down Expand Up @@ -50,6 +51,7 @@ export const ExampleImageCard = ({
});

batch(() => {
dispatch(projectSlice.actions.setProjectImageChannels({ channels: 3 }));
dispatch(
dataSlice.actions.initializeState({
data,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -132,6 +132,12 @@ export const ExampleProjectCard = ({

batch(() => {
// loadPercent will be set to 1 here
dispatch(
projectSlice.actions.setProjectImageChannels({
channels: Object.values(data.things.entities)[0].saved.shape
.channels,
})
);
dispatch(dataSlice.actions.initializeState({ data }));
dispatch(projectSlice.actions.setProject({ project }));

Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -77,7 +77,7 @@ export const DialogWithAction = ({
</IconButton>
</Box>

{content && <DialogContent>{content}</DialogContent>}
{content && <DialogContent sx={{ py: 0 }}>{content}</DialogContent>}

<DialogActions>
<Button
Expand Down
152 changes: 0 additions & 152 deletions src/components/dialogs/ImageShapeDialog/ImageShapeDialog.tsx

This file was deleted.

1 change: 0 additions & 1 deletion src/components/dialogs/ImageShapeDialog/index.ts

This file was deleted.

Original file line number Diff line number Diff line change
Expand Up @@ -32,6 +32,8 @@ import { availableClassifierModels } from "utils/models/availableClassificationM
import { availableSegmenterModels } from "utils/models/availableSegmentationModels";
import { HotkeyView } from "utils/common/enums";
import { Shape } from "store/data/types";
import { selectProjectImageChannels } from "store/project/selectors";
import { useSelector } from "react-redux";

const ToolTipTab = (
props: TabProps & {
Expand Down Expand Up @@ -71,18 +73,25 @@ const ToolTipTab = (

type ImportTensorflowModelDialogProps = {
onClose: () => void;
loadedModel?: Model;
open: boolean;
modelTask: ModelTask;
dispatchFunction: (model: Model, inputShape: Shape) => void;
};

export const ImportTensorflowModelDialog = ({
onClose,
loadedModel,
open,
modelTask,
dispatchFunction,
}: ImportTensorflowModelDialogProps) => {
const [selectedModel, setSelectedModel] = useState<Model>();
const projectChannels = useSelector(selectProjectImageChannels);
const [selectedModel, setSelectedModel] = useState<Model | undefined>(
loadedModel?.name === "Fully Convolutional Network"
? undefined
: loadedModel
);
const [inputShape, setInputShape] = useState<Shape>({
height: 256,
width: 256,
Expand All @@ -96,6 +105,7 @@ export const ImportTensorflowModelDialog = ({
const [cloudWarning, setCloudWarning] = useState(false);

const [tabVal, setTabVal] = useState("1");
const [invalidModel, setInvalidModel] = useState(false);

const onModelChange = useCallback((model: Model | undefined) => {
setSelectedModel(model);
Expand All @@ -116,11 +126,15 @@ export const ImportTensorflowModelDialog = ({

dispatchFunction(selectedModel, inputShape);

closeDialog();
setCloudWarning(false);
setInvalidModel(false);
onClose();
};

const closeDialog = () => {
setCloudWarning(false);
setInvalidModel(false);
setSelectedModel(loadedModel);
onClose();
};

Expand Down Expand Up @@ -153,6 +167,16 @@ export const ImportTensorflowModelDialog = ({
);
}, [modelTask]);

useEffect(() => {
if (modelTask === ModelTask.Segmentation) {
if (selectedModel && selectedModel.requiredChannels !== projectChannels) {
setInvalidModel(true);
} else {
setInvalidModel(false);
}
}
}, [modelTask, projectChannels, selectedModel]);

return (
<Dialog fullWidth maxWidth="xs" onClose={closeDialog} open={open}>
<Collapse in={cloudWarning}>
Expand Down Expand Up @@ -212,7 +236,22 @@ export const ImportTensorflowModelDialog = ({
<Box hidden={tabVal !== "1"}>
<PretrainedModelSelector
values={pretrainedModels}
initModel={
selectedModel
? pretrainedModels.findIndex(
(model) => model.name === selectedModel.name
) + ""
: "-1"
}
setModel={onModelChange}
error={invalidModel}
errorText={
!selectedModel
? "Select a Model"
: invalidModel
? `Model requires ${selectedModel.requiredChannels}-channel images`
: ""
}
/>
</Box>

Expand Down Expand Up @@ -246,7 +285,7 @@ export const ImportTensorflowModelDialog = ({
<Button
onClick={dispatchModelToStore}
color="primary"
disabled={!selectedModel}
disabled={!selectedModel || invalidModel}
>
Open{" "}
{modelTask === ModelTask.Classification
Expand Down
Loading
Loading