Skip to content

Commit

Permalink
[Chore] Add hfModelFile to preset models for alignment with HF-source…
Browse files Browse the repository at this point in the history
…d models & shortens the preset model lists (#145)

* feat: add skills to the models

* chore: added hfModelFile to preset models - so benchmark for preset models can be shared

* feat: enable iOS downloads to occur in the background by default.

* chore: bump model version and minor fix for model merge algo

* fix: numeric values (stored as string) for CompletionSettings
  • Loading branch information
a-ghorbani authored Dec 23, 2024
1 parent aa95a19 commit f655ef3
Show file tree
Hide file tree
Showing 10 changed files with 206 additions and 461 deletions.
1 change: 1 addition & 0 deletions jest/fixtures/models.ts
Original file line number Diff line number Diff line change
Expand Up @@ -62,6 +62,7 @@ export const mockBasicModel: Model = {
name: 'Test Model 1',
author: 'test-author',
type: 'Test Model Type',
description: 'Test model description',
size: 2 * 10 ** 9,
params: 2 * 10 ** 9,
isDownloaded: false,
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -44,17 +44,16 @@ export const ModelFileCard: FC<ModelFileCardProps> = observer(
const isBookmarked = computed(() =>
modelStore.models.some(
model =>
model.origin === ModelOrigin.HF &&
//model.origin === ModelOrigin.HF &&
model.hfModelFile?.oid === modelFile.oid,
),
).get();

const isDownloaded = computed(() =>
modelStore.models.some(
model =>
model.origin === ModelOrigin.HF &&
model.hfModelFile?.oid === modelFile.oid &&
model.isDownloaded,
//model.origin === ModelOrigin.HF &&
model.hfModelFile?.oid === modelFile.oid && model.isDownloaded,
),
).get();

Expand All @@ -69,7 +68,9 @@ export const ModelFileCard: FC<ModelFileCardProps> = observer(
const model = modelStore.models.find(
(m: Model) => m.hfModelFile?.oid === modelFile.oid,
);
if (model && model.isDownloaded) {
if (model && model.origin === ModelOrigin.PRESET) {
Alert.alert('Cannot Remove', 'The model is preset.');
} else if (model && model.isDownloaded) {
Alert.alert(
'Cannot Remove',
'The model is downloaded. Please delete the file first.',
Expand Down
135 changes: 89 additions & 46 deletions src/screens/ModelsScreen/ModelCard/ModelCard.tsx
Original file line number Diff line number Diff line change
@@ -1,5 +1,5 @@
import React, {useCallback, useState, useEffect} from 'react';
import {Alert, Linking, View, Image} from 'react-native';
import {Alert, Linking, View} from 'react-native';

import {observer} from 'mobx-react-lite';
import {useNavigation} from '@react-navigation/native';
Expand Down Expand Up @@ -28,7 +28,10 @@ import {uiStore, modelStore} from '../../../store';

import {chatTemplates} from '../../../utils/chat';
import {getModelDescription, L10nContext} from '../../../utils';
import {validateCompletionSettings} from '../../../utils/modelSettings';
import {
COMPLETION_PARAMS_METADATA,
validateCompletionSettings,
} from '../../../utils/modelSettings';
import {Model, ModelOrigin, RootDrawerParamList} from '../../../utils/types';

type ChatScreenNavigationProp = DrawerNavigationProp<RootDrawerParamList>;
Expand Down Expand Up @@ -100,15 +103,54 @@ export const ModelCard: React.FC<ModelCardProps> = observer(
}, []);

const handleSaveSettings = useCallback(() => {
const {isValid, errors} = validateCompletionSettings(
tempCompletionSettings,
// Convert string values to numbers where needed
const processedSettings = Object.entries(tempCompletionSettings).reduce(
(acc, [key, value]) => {
const metadata = COMPLETION_PARAMS_METADATA[key];
if (metadata?.validation.type === 'numeric') {
// Handle numeric conversion
let numValue: number;
if (typeof value === 'string') {
numValue = Number(value);
} else if (typeof value === 'number') {
numValue = value;
} else {
// If it's neither string nor number, treat as invalid. Most probably won't happen.
acc.errors[key] = 'Must be a valid number';
return acc;
}

if (Number.isNaN(numValue)) {
acc.errors[key] = 'Must be a valid number';
} else {
acc.settings[key] = numValue;
}
} else {
// For non-numeric values, keep as is
acc.settings[key] = value;
}
return acc;
},
{settings: {}, errors: {}} as {
settings: typeof tempCompletionSettings;
errors: Record<string, string>;
},
);

if (!isValid) {
// Validate the converted values
const validationResult = validateCompletionSettings(
processedSettings.settings,
);
const allErrors = {
...processedSettings.errors,
...validationResult.errors,
};

if (Object.keys(allErrors).length > 0) {
Alert.alert(
'Invalid Values',
'Please correct the following:\n' +
Object.entries(errors)
Object.entries(allErrors)
.map(([key, msg]) => `• ${key}: ${msg}`)
.join('\n'),
[{text: 'OK'}],
Expand All @@ -118,7 +160,7 @@ export const ModelCard: React.FC<ModelCardProps> = observer(

// All validations passed, save the settings
modelStore.updateModelChatTemplate(model.id, tempChatTemplate);
modelStore.updateCompletionSettings(model.id, tempCompletionSettings);
modelStore.updateCompletionSettings(model.id, processedSettings.settings);
handleCloseSettings();
}, [
model.id,
Expand Down Expand Up @@ -300,12 +342,6 @@ export const ModelCard: React.FC<ModelCardProps> = observer(
isActiveModel && {backgroundColor: theme.colors.tertiaryContainer},
{borderColor: theme.colors.primary},
]}>
{isHfModel && (
<Image
source={require('../../../assets/icon-hf.png')}
style={styles.hfBadge}
/>
)}
<View style={styles.cardInner}>
<View style={styles.cardContent}>
<View style={styles.headerRow}>
Expand All @@ -327,45 +363,52 @@ export const ModelCard: React.FC<ModelCardProps> = observer(
<Text style={styles.modelDescription}>
{getModelDescription(model, isActiveModel, modelStore)}
</Text>
{model.description && (
<View style={styles.descriptionContainer}>
<Text style={styles.skillsLabel}>Skills: </Text>
<Text style={styles.skillsText}>{model.description}</Text>
</View>
)}
</View>
</View>
</View>

{/* Display warning icon if there's a memory warning */}
{shortMemoryWarning && isDownloaded && (
<TouchableRipple
testID="memory-warning-button"
onPress={handleWarningPress}
style={styles.warningContainer}>
<View style={styles.warningContent}>
<IconButton
icon="alert-circle-outline"
iconColor={theme.colors.error}
size={20}
style={styles.warningIcon}
{/* Display warning icon if there's a memory warning */}
{shortMemoryWarning && isDownloaded && (
<TouchableRipple
testID="memory-warning-button"
onPress={handleWarningPress}
style={styles.warningContainer}>
<View style={styles.warningContent}>
<IconButton
icon="alert-circle-outline"
iconColor={theme.colors.error}
size={20}
style={styles.warningIcon}
/>
<Text style={styles.warningText}>{shortMemoryWarning}</Text>
</View>
</TouchableRipple>
)}

{isDownloading && (
<>
<ProgressBar
testID="download-progress-bar"
progress={modelStore.getDownloadProgress(model.id)}
color={theme.colors.tertiary}
style={styles.progressBar}
/>
<Text style={styles.warningText}>{shortMemoryWarning}</Text>
</View>
</TouchableRipple>
)}

{isDownloading && (
<>
<ProgressBar
testID="download-progress-bar"
progress={modelStore.getDownloadProgress(model.id)}
color={theme.colors.tertiary}
style={styles.progressBar}
/>
{model.downloadSpeed && (
<Paragraph style={styles.downloadSpeed}>
{model.downloadSpeed}
</Paragraph>
)}
</>
)}
{model.downloadSpeed && (
<Paragraph style={styles.downloadSpeed}>
{model.downloadSpeed}
</Paragraph>
)}
</>
)}
</View>

<Divider style={styles.divider} />

{isDownloaded ? (
<Card.Actions style={styles.actions}>
<Button
Expand Down
20 changes: 19 additions & 1 deletion src/screens/ModelsScreen/ModelCard/styles.ts
Original file line number Diff line number Diff line change
Expand Up @@ -83,9 +83,26 @@ export const createStyles = (theme: Theme) =>
},
modelDescription: {
fontSize: 12,
marginTop: 2,
marginVertical: 4,
color: theme.colors.onSurfaceVariant,
},
skillsLabel: {
fontSize: 12,
fontWeight: 'bold',
color: theme.colors.primary,
marginRight: 4,
},
skillsText: {
fontSize: 12,
color: theme.colors.onSurfaceVariant,
flexShrink: 1,
},
descriptionContainer: {
marginTop: 4,
flexDirection: 'row',
flexWrap: 'wrap',
alignItems: 'flex-start',
},
hfButton: {
margin: 0,
padding: 0,
Expand Down Expand Up @@ -125,6 +142,7 @@ export const createStyles = (theme: Theme) =>
},
storageErrorText: {
fontWeight: 'bold',
marginHorizontal: 8,
},
loadingContainer: {
flexDirection: 'row',
Expand Down
6 changes: 5 additions & 1 deletion src/store/ModelStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -85,7 +85,10 @@ class ModelStore {
};

mergeModelLists = () => {
const mergedModels = [...this.models]; // Start with persisted models
// Start with persisted models, but filter out non-downloaded preset models
const mergedModels = [...this.models].filter(
model => model.origin !== ModelOrigin.PRESET || model.isDownloaded,
);

// Handle PRESET models using defaultModels as reference
defaultModels.forEach(defaultModel => {
Expand Down Expand Up @@ -674,6 +677,7 @@ class ModelStore {
author: '',
name: filename,
size: 0, // Placeholder for UI to ignore
description: '',
params: 0, // Placeholder for UI to ignore
isDownloaded: true,
downloadUrl: '',
Expand Down
2 changes: 1 addition & 1 deletion src/store/UIStore.ts
Original file line number Diff line number Diff line change
Expand Up @@ -27,7 +27,7 @@ export class UIStore {

displayMemUsage = false;

iOSBackgroundDownloading = false;
iOSBackgroundDownloading = true;

benchmarkShareDialog = {
shouldShow: true,
Expand Down
1 change: 1 addition & 0 deletions src/store/__tests__/ModelStore.test.ts
Original file line number Diff line number Diff line change
Expand Up @@ -63,6 +63,7 @@ describe('ModelStore', () => {
...completionSettingsWithoutTemperature, // Use the completionSettings without temperature - simulates new parameters
n_predict: 101010,
},
isDownloaded: true, // if not downloaded, it will be removed
};

modelStore.models[0] = existingModel;
Expand Down
Loading

0 comments on commit f655ef3

Please sign in to comment.