diff --git a/jest/fixtures/models.ts b/jest/fixtures/models.ts index 0dd41aa..e9f6147 100644 --- a/jest/fixtures/models.ts +++ b/jest/fixtures/models.ts @@ -62,6 +62,7 @@ export const mockBasicModel: Model = { name: 'Test Model 1', author: 'test-author', type: 'Test Model Type', + description: 'Test model description', size: 2 * 10 ** 9, params: 2 * 10 ** 9, isDownloaded: false, diff --git a/src/screens/ModelsScreen/HFModelSearch/DetailsView/ModelFileCard/ModelFileCard.tsx b/src/screens/ModelsScreen/HFModelSearch/DetailsView/ModelFileCard/ModelFileCard.tsx index 545ced3..16c03f2 100644 --- a/src/screens/ModelsScreen/HFModelSearch/DetailsView/ModelFileCard/ModelFileCard.tsx +++ b/src/screens/ModelsScreen/HFModelSearch/DetailsView/ModelFileCard/ModelFileCard.tsx @@ -44,7 +44,7 @@ export const ModelFileCard: FC = observer( const isBookmarked = computed(() => modelStore.models.some( model => - model.origin === ModelOrigin.HF && + //model.origin === ModelOrigin.HF && model.hfModelFile?.oid === modelFile.oid, ), ).get(); @@ -52,9 +52,8 @@ export const ModelFileCard: FC = observer( const isDownloaded = computed(() => modelStore.models.some( model => - model.origin === ModelOrigin.HF && - model.hfModelFile?.oid === modelFile.oid && - model.isDownloaded, + //model.origin === ModelOrigin.HF && + model.hfModelFile?.oid === modelFile.oid && model.isDownloaded, ), ).get(); @@ -69,7 +68,9 @@ export const ModelFileCard: FC = observer( const model = modelStore.models.find( (m: Model) => m.hfModelFile?.oid === modelFile.oid, ); - if (model && model.isDownloaded) { + if (model && model.origin === ModelOrigin.PRESET) { + Alert.alert('Cannot Remove', 'The model is preset.'); + } else if (model && model.isDownloaded) { Alert.alert( 'Cannot Remove', 'The model is downloaded. Please delete the file first.', diff --git a/src/screens/ModelsScreen/ModelCard/ModelCard.tsx b/src/screens/ModelsScreen/ModelCard/ModelCard.tsx index af7b885..ffd2fff 100644 --- a/src/screens/ModelsScreen/ModelCard/ModelCard.tsx +++ b/src/screens/ModelsScreen/ModelCard/ModelCard.tsx @@ -1,5 +1,5 @@ import React, {useCallback, useState, useEffect} from 'react'; -import {Alert, Linking, View, Image} from 'react-native'; +import {Alert, Linking, View} from 'react-native'; import {observer} from 'mobx-react-lite'; import {useNavigation} from '@react-navigation/native'; @@ -28,7 +28,10 @@ import {uiStore, modelStore} from '../../../store'; import {chatTemplates} from '../../../utils/chat'; import {getModelDescription, L10nContext} from '../../../utils'; -import {validateCompletionSettings} from '../../../utils/modelSettings'; +import { + COMPLETION_PARAMS_METADATA, + validateCompletionSettings, +} from '../../../utils/modelSettings'; import {Model, ModelOrigin, RootDrawerParamList} from '../../../utils/types'; type ChatScreenNavigationProp = DrawerNavigationProp; @@ -100,15 +103,54 @@ export const ModelCard: React.FC = observer( }, []); const handleSaveSettings = useCallback(() => { - const {isValid, errors} = validateCompletionSettings( - tempCompletionSettings, + // Convert string values to numbers where needed + const processedSettings = Object.entries(tempCompletionSettings).reduce( + (acc, [key, value]) => { + const metadata = COMPLETION_PARAMS_METADATA[key]; + if (metadata?.validation.type === 'numeric') { + // Handle numeric conversion + let numValue: number; + if (typeof value === 'string') { + numValue = Number(value); + } else if (typeof value === 'number') { + numValue = value; + } else { + // If it's neither string nor number, treat as invalid. Most probably won't happen. + acc.errors[key] = 'Must be a valid number'; + return acc; + } + + if (Number.isNaN(numValue)) { + acc.errors[key] = 'Must be a valid number'; + } else { + acc.settings[key] = numValue; + } + } else { + // For non-numeric values, keep as is + acc.settings[key] = value; + } + return acc; + }, + {settings: {}, errors: {}} as { + settings: typeof tempCompletionSettings; + errors: Record; + }, ); - if (!isValid) { + // Validate the converted values + const validationResult = validateCompletionSettings( + processedSettings.settings, + ); + const allErrors = { + ...processedSettings.errors, + ...validationResult.errors, + }; + + if (Object.keys(allErrors).length > 0) { Alert.alert( 'Invalid Values', 'Please correct the following:\n' + - Object.entries(errors) + Object.entries(allErrors) .map(([key, msg]) => `• ${key}: ${msg}`) .join('\n'), [{text: 'OK'}], @@ -118,7 +160,7 @@ export const ModelCard: React.FC = observer( // All validations passed, save the settings modelStore.updateModelChatTemplate(model.id, tempChatTemplate); - modelStore.updateCompletionSettings(model.id, tempCompletionSettings); + modelStore.updateCompletionSettings(model.id, processedSettings.settings); handleCloseSettings(); }, [ model.id, @@ -300,12 +342,6 @@ export const ModelCard: React.FC = observer( isActiveModel && {backgroundColor: theme.colors.tertiaryContainer}, {borderColor: theme.colors.primary}, ]}> - {isHfModel && ( - - )} @@ -327,45 +363,52 @@ export const ModelCard: React.FC = observer( {getModelDescription(model, isActiveModel, modelStore)} + {model.description && ( + + Skills: + {model.description} + + )} - - {/* Display warning icon if there's a memory warning */} - {shortMemoryWarning && isDownloaded && ( - - - + + + {shortMemoryWarning} + + + )} + + {isDownloading && ( + <> + - {shortMemoryWarning} - - - )} - - {isDownloading && ( - <> - - {model.downloadSpeed && ( - - {model.downloadSpeed} - - )} - - )} + {model.downloadSpeed && ( + + {model.downloadSpeed} + + )} + + )} + + {isDownloaded ? (