Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

[Chore] Add hfModelFile to preset models for alignment with HF-sourced models & shortens the preset model lists #145

Merged
merged 7 commits into from
Dec 23, 2024
1 change: 1 addition & 0 deletions jest/fixtures/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ export const mockBasicModel: Model = {
name: 'Test Model 1',
author: 'test-author',
type: 'Test Model Type',
description: 'Test model description',
size: 2 * 10 ** 9,
params: 2 * 10 ** 9,
isDownloaded: false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,16 @@ export const ModelFileCard: FC<ModelFileCardProps> = observer(
const isBookmarked = computed(() =>
modelStore.models.some(
model =>
model.origin === ModelOrigin.HF &&
//model.origin === ModelOrigin.HF &&
model.hfModelFile?.oid === modelFile.oid,
),
).get();

const isDownloaded = computed(() =>
modelStore.models.some(
model =>
model.origin === ModelOrigin.HF &&
model.hfModelFile?.oid === modelFile.oid &&
model.isDownloaded,
//model.origin === ModelOrigin.HF &&
model.hfModelFile?.oid === modelFile.oid && model.isDownloaded,
),
).get();

Expand All @@ -69,7 +68,9 @@ export const ModelFileCard: FC<ModelFileCardProps> = observer(
const model = modelStore.models.find(
(m: Model) => m.hfModelFile?.oid === modelFile.oid,
);
if (model && model.isDownloaded) {
if (model && model.origin === ModelOrigin.PRESET) {
Alert.alert('Cannot Remove', 'The model is preset.');
} else if (model && model.isDownloaded) {
Alert.alert(
'Cannot Remove',
'The model is downloaded. Please delete the file first.',
Expand Down
135 changes: 89 additions & 46 deletions src/screens/ModelsScreen/ModelCard/ModelCard.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import React, {useCallback, useState, useEffect} from 'react';
import {Alert, Linking, View, Image} from 'react-native';
import {Alert, Linking, View} from 'react-native';

import {observer} from 'mobx-react-lite';
import {useNavigation} from '@react-navigation/native';
Expand Down Expand Up @@ -28,7 +28,10 @@ import {uiStore, modelStore} from '../../../store';

import {chatTemplates} from '../../../utils/chat';
import {getModelDescription, L10nContext} from '../../../utils';
import {validateCompletionSettings} from '../../../utils/modelSettings';
import {
COMPLETION_PARAMS_METADATA,
validateCompletionSettings,
} from '../../../utils/modelSettings';
import {Model, ModelOrigin, RootDrawerParamList} from '../../../utils/types';

type ChatScreenNavigationProp = DrawerNavigationProp<RootDrawerParamList>;
Expand Down Expand Up @@ -100,15 +103,54 @@ export const ModelCard: React.FC<ModelCardProps> = observer(
}, []);

const handleSaveSettings = useCallback(() => {
const {isValid, errors} = validateCompletionSettings(
tempCompletionSettings,
// Convert string values to numbers where needed
const processedSettings = Object.entries(tempCompletionSettings).reduce(
(acc, [key, value]) => {
const metadata = COMPLETION_PARAMS_METADATA[key];
if (metadata?.validation.type === 'numeric') {
// Handle numeric conversion
let numValue: number;
if (typeof value === 'string') {
numValue = Number(value);
} else if (typeof value === 'number') {
numValue = value;
} else {
// If it's neither string nor number, treat as invalid. Most probably won't happen.
acc.errors[key] = 'Must be a valid number';
return acc;
}

if (Number.isNaN(numValue)) {
acc.errors[key] = 'Must be a valid number';
} else {
acc.settings[key] = numValue;
}
} else {
// For non-numeric values, keep as is
acc.settings[key] = value;
}
return acc;
},
{settings: {}, errors: {}} as {
settings: typeof tempCompletionSettings;
errors: Record<string, string>;
},
);

if (!isValid) {
// Validate the converted values
const validationResult = validateCompletionSettings(
processedSettings.settings,
);
const allErrors = {
...processedSettings.errors,
...validationResult.errors,
};

if (Object.keys(allErrors).length > 0) {
Alert.alert(
'Invalid Values',
'Please correct the following:\n' +
Object.entries(errors)
Object.entries(allErrors)
.map(([key, msg]) => `• ${key}: ${msg}`)
.join('\n'),
[{text: 'OK'}],
Expand All @@ -118,7 +160,7 @@ export const ModelCard: React.FC<ModelCardProps> = observer(

// All validations passed, save the settings
modelStore.updateModelChatTemplate(model.id, tempChatTemplate);
modelStore.updateCompletionSettings(model.id, tempCompletionSettings);
modelStore.updateCompletionSettings(model.id, processedSettings.settings);
handleCloseSettings();
}, [
model.id,
Expand Down Expand Up @@ -300,12 +342,6 @@ export const ModelCard: React.FC<ModelCardProps> = observer(
isActiveModel && {backgroundColor: theme.colors.tertiaryContainer},
{borderColor: theme.colors.primary},
]}>
{isHfModel && (
<Image
source={require('../../../assets/icon-hf.png')}
style={styles.hfBadge}
/>
)}
<View style={styles.cardInner}>
<View style={styles.cardContent}>
<View style={styles.headerRow}>
Expand All @@ -327,45 +363,52 @@ export const ModelCard: React.FC<ModelCardProps> = observer(
<Text style={styles.modelDescription}>
{getModelDescription(model, isActiveModel, modelStore)}
</Text>
{model.description && (
<View style={styles.descriptionContainer}>
<Text style={styles.skillsLabel}>Skills: </Text>
<Text style={styles.skillsText}>{model.description}</Text>
</View>
)}
</View>
</View>
</View>

{/* Display warning icon if there's a memory warning */}
{shortMemoryWarning && isDownloaded && (
<TouchableRipple
testID="memory-warning-button"
onPress={handleWarningPress}
style={styles.warningContainer}>
<View style={styles.warningContent}>
<IconButton
icon="alert-circle-outline"
iconColor={theme.colors.error}
size={20}
style={styles.warningIcon}
{/* Display warning icon if there's a memory warning */}
{shortMemoryWarning && isDownloaded && (
<TouchableRipple
testID="memory-warning-button"
onPress={handleWarningPress}
style={styles.warningContainer}>
<View style={styles.warningContent}>
<IconButton
icon="alert-circle-outline"
iconColor={theme.colors.error}
size={20}
style={styles.warningIcon}
/>
<Text style={styles.warningText}>{shortMemoryWarning}</Text>
</View>
</TouchableRipple>
)}

{isDownloading && (
<>
<ProgressBar
testID="download-progress-bar"
progress={modelStore.getDownloadProgress(model.id)}
color={theme.colors.tertiary}
style={styles.progressBar}
/>
<Text style={styles.warningText}>{shortMemoryWarning}</Text>
</View>
</TouchableRipple>
)}

{isDownloading && (
<>
<ProgressBar
testID="download-progress-bar"
progress={modelStore.getDownloadProgress(model.id)}
color={theme.colors.tertiary}
style={styles.progressBar}
/>
{model.downloadSpeed && (
<Paragraph style={styles.downloadSpeed}>
{model.downloadSpeed}
</Paragraph>
)}
</>
)}
{model.downloadSpeed && (
<Paragraph style={styles.downloadSpeed}>
{model.downloadSpeed}
</Paragraph>
)}
</>
)}
</View>

<Divider style={styles.divider} />

{isDownloaded ? (
<Card.Actions style={styles.actions}>
<Button
Expand Down
20 changes: 19 additions & 1 deletion src/screens/ModelsScreen/ModelCard/styles.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,26 @@ export const createStyles = (theme: Theme) =>
},
modelDescription: {
fontSize: 12,
marginTop: 2,
marginVertical: 4,
color: theme.colors.onSurfaceVariant,
},
skillsLabel: {
fontSize: 12,
fontWeight: 'bold',
color: theme.colors.primary,
marginRight: 4,
},
skillsText: {
fontSize: 12,
color: theme.colors.onSurfaceVariant,
flexShrink: 1,
},
descriptionContainer: {
marginTop: 4,
flexDirection: 'row',
flexWrap: 'wrap',
alignItems: 'flex-start',
},
hfButton: {
margin: 0,
padding: 0,
Expand Down Expand Up @@ -125,6 +142,7 @@ export const createStyles = (theme: Theme) =>
},
storageErrorText: {
fontWeight: 'bold',
marginHorizontal: 8,
},
loadingContainer: {
flexDirection: 'row',
Expand Down
6 changes: 5 additions & 1 deletion src/store/ModelStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ class ModelStore {
};

mergeModelLists = () => {
const mergedModels = [...this.models]; // Start with persisted models
// Start with persisted models, but filter out non-downloaded preset models
const mergedModels = [...this.models].filter(
model => model.origin !== ModelOrigin.PRESET || model.isDownloaded,
);

// Handle PRESET models using defaultModels as reference
defaultModels.forEach(defaultModel => {
Expand Down Expand Up @@ -674,6 +677,7 @@ class ModelStore {
author: '',
name: filename,
size: 0, // Placeholder for UI to ignore
description: '',
params: 0, // Placeholder for UI to ignore
isDownloaded: true,
downloadUrl: '',
Expand Down
2 changes: 1 addition & 1 deletion src/store/UIStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export class UIStore {

displayMemUsage = false;

iOSBackgroundDownloading = false;
iOSBackgroundDownloading = true;

benchmarkShareDialog = {
shouldShow: true,
Expand Down
1 change: 1 addition & 0 deletions src/store/__tests__/ModelStore.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ describe('ModelStore', () => {
...completionSettingsWithoutTemperature, // Use the completionSettings without temperature - simulates new parameters
n_predict: 101010,
},
isDownloaded: true, // if not downloaded, it will be removed
};

modelStore.models[0] = existingModel;
Expand Down
Loading
Loading