Skip to content
New issue

Have a question about this project? Sign up for a free GitHub account to open an issue and contact its maintainers and the community.

By clicking “Sign up for GitHub”, you agree to our terms of service and privacy statement. We’ll occasionally send you account related emails.

Already on GitHub? Sign in to your account

fix: #3491 - Unable to use tensorrt-llm #3741

Merged
merged 2 commits into from
Sep 30, 2024
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension


Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
9 changes: 9 additions & 0 deletions extensions/tensorrt-llm-extension/jest.config.js
Original file line number Diff line number Diff line change
@@ -0,0 +1,9 @@
/** @type {import('ts-jest').JestConfigWithTsJest} */
module.exports = {
preset: 'ts-jest',
testEnvironment: 'node',
transform: {
'node_modules/@janhq/core/.+\\.(j|t)s?$': 'ts-jest',
},
transformIgnorePatterns: ['node_modules/(?!@janhq/core/.*)'],
}
8 changes: 7 additions & 1 deletion extensions/tensorrt-llm-extension/package.json
Original file line number Diff line number Diff line change
Expand Up @@ -22,6 +22,7 @@
"tensorrtVersion": "0.1.8",
"provider": "nitro-tensorrt-llm",
"scripts": {
"test": "jest",
"build": "tsc --module commonjs && rollup -c rollup.config.ts",
"build:publish:win32": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
"build:publish:linux": "rimraf *.tgz --glob && yarn build && cpx \"bin/**\" \"dist/bin\" && npm pack && cpx *.tgz ../../pre-install",
Expand Down Expand Up @@ -49,7 +50,12 @@
"rollup-plugin-sourcemaps": "^0.6.3",
"rollup-plugin-typescript2": "^0.36.0",
"run-script-os": "^1.1.6",
"typescript": "^5.2.2"
"typescript": "^5.2.2",
"@types/jest": "^29.5.12",
"jest": "^29.7.0",
"jest-junit": "^16.0.0",
"jest-runner": "^29.7.0",
"ts-jest": "^29.2.5"
},
"dependencies": {
"@janhq/core": "file:../../core",
Expand Down
4 changes: 2 additions & 2 deletions extensions/tensorrt-llm-extension/rollup.config.ts
Original file line number Diff line number Diff line change
Expand Up @@ -23,10 +23,10 @@ export default [
DOWNLOAD_RUNNER_URL:
process.platform === 'win32'
? JSON.stringify(
'https://github.com/janhq/nitro-tensorrt-llm/releases/download/windows-v<version>-tensorrt-llm-v0.7.1/nitro-windows-v<version>-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz'
'https://github.com/janhq/cortex.tensorrt-llm/releases/download/windows-v<version>-tensorrt-llm-v0.7.1/nitro-windows-v<version>-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz'
)
: JSON.stringify(
'https://github.com/janhq/nitro-tensorrt-llm/releases/download/linux-v<version>/nitro-linux-v<version>-amd64-tensorrt-llm-<gpuarch>.tar.gz'
'https://github.com/janhq/cortex.tensorrt-llm/releases/download/linux-v<version>/nitro-linux-v<version>-amd64-tensorrt-llm-<gpuarch>.tar.gz'
),
NODE: JSON.stringify(`${packageJson.name}/${packageJson.node}`),
INFERENCE_URL: JSON.stringify(
Expand Down
186 changes: 186 additions & 0 deletions extensions/tensorrt-llm-extension/src/index.test.ts
Original file line number Diff line number Diff line change
@@ -0,0 +1,186 @@
import TensorRTLLMExtension from '../src/index'
import {
executeOnMain,
systemInformation,
fs,
baseName,
joinPath,
downloadFile,
} from '@janhq/core'

jest.mock('@janhq/core', () => ({
...jest.requireActual('@janhq/core/node'),
LocalOAIEngine: jest.fn().mockImplementation(function () {
// @ts-ignore
this.registerModels = () => {
return Promise.resolve()
}
// @ts-ignore
return this
}),
systemInformation: jest.fn(),
fs: {
existsSync: jest.fn(),
mkdir: jest.fn(),
},
joinPath: jest.fn(),
baseName: jest.fn(),
downloadFile: jest.fn(),
executeOnMain: jest.fn(),
showToast: jest.fn(),
events: {
emit: jest.fn(),
// @ts-ignore
on: (event, func) => {
func({ fileName: './' })
},
off: jest.fn(),
},
}))

// @ts-ignore
global.COMPATIBILITY = {
platform: ['win32'],
}
// @ts-ignore
global.PROVIDER = 'tensorrt-llm'
// @ts-ignore
global.INFERENCE_URL = 'http://localhost:5000'
// @ts-ignore
global.NODE = 'node'
// @ts-ignore
global.MODELS = []
// @ts-ignore
global.TENSORRT_VERSION = ''
// @ts-ignore
global.DOWNLOAD_RUNNER_URL = ''

describe('TensorRTLLMExtension', () => {
let extension: TensorRTLLMExtension

beforeEach(() => {
// @ts-ignore
extension = new TensorRTLLMExtension()
jest.clearAllMocks()
})

describe('compatibility', () => {
it('should return the correct compatibility', () => {
const result = extension.compatibility()
expect(result).toEqual({
platform: ['win32'],
})
})
})

describe('install', () => {
it('should install if compatible', async () => {
const mockSystemInfo: any = {
osInfo: { platform: 'win32' },
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
}
;(executeOnMain as jest.Mock).mockResolvedValue({})
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
;(fs.existsSync as jest.Mock).mockResolvedValue(false)
;(fs.mkdir as jest.Mock).mockResolvedValue(undefined)
;(baseName as jest.Mock).mockResolvedValue('./')
;(joinPath as jest.Mock).mockResolvedValue('./')
;(downloadFile as jest.Mock).mockResolvedValue({})

await extension.install()

expect(executeOnMain).toHaveBeenCalled()
})

it('should not install if not compatible', async () => {
const mockSystemInfo: any = {
osInfo: { platform: 'linux' },
gpuSetting: { gpus: [{ arch: 'pascal', name: 'NVIDIA GPU' }] },
}
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)

jest.spyOn(extension, 'registerModels').mockReturnValue(Promise.resolve())
await extension.install()

expect(executeOnMain).not.toHaveBeenCalled()
})
})

describe('installationState', () => {
it('should return NotCompatible if not compatible', async () => {
const mockSystemInfo: any = {
osInfo: { platform: 'linux' },
gpuSetting: { gpus: [{ arch: 'pascal', name: 'NVIDIA GPU' }] },
}
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)

const result = await extension.installationState()

expect(result).toBe('NotCompatible')
})

it('should return Installed if executable exists', async () => {
const mockSystemInfo: any = {
osInfo: { platform: 'win32' },
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
}
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
;(fs.existsSync as jest.Mock).mockResolvedValue(true)

const result = await extension.installationState()

expect(result).toBe('Installed')
})

it('should return NotInstalled if executable does not exist', async () => {
const mockSystemInfo: any = {
osInfo: { platform: 'win32' },
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
}
;(systemInformation as jest.Mock).mockResolvedValue(mockSystemInfo)
;(fs.existsSync as jest.Mock).mockResolvedValue(false)

const result = await extension.installationState()

expect(result).toBe('NotInstalled')
})
})

describe('isCompatible', () => {
it('should return true for compatible system', () => {
const mockInfo: any = {
osInfo: { platform: 'win32' },
gpuSetting: { gpus: [{ arch: 'ampere', name: 'NVIDIA GPU' }] },
}

const result = extension.isCompatible(mockInfo)

expect(result).toBe(true)
})

it('should return false for incompatible system', () => {
const mockInfo: any = {
osInfo: { platform: 'linux' },
gpuSetting: { gpus: [{ arch: 'pascal', name: 'AMD GPU' }] },
}

const result = extension.isCompatible(mockInfo)

expect(result).toBe(false)
})
})
})

describe('GitHub Release File URL Test', () => {
const url = 'https://github.com/janhq/cortex.tensorrt-llm/releases/download/windows-v0.1.8-tensorrt-llm-v0.7.1/nitro-windows-v0.1.8-tensorrt-llm-v0.7.1-amd64-all-arch.tar.gz';

it('should return a status code 200 for the release file URL', async () => {
const response = await fetch(url, { method: 'HEAD' });
expect(response.status).toBe(200);
});

it('should not return a 404 status', async () => {
const response = await fetch(url, { method: 'HEAD' });
expect(response.status).not.toBe(404);
});
});
3 changes: 1 addition & 2 deletions extensions/tensorrt-llm-extension/src/index.ts
Original file line number Diff line number Diff line change
Expand Up @@ -41,7 +41,6 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
override nodeModule = NODE

private supportedGpuArch = ['ampere', 'ada']
private supportedPlatform = ['win32', 'linux']

override compatibility() {
return COMPATIBILITY as unknown as Compatibility
Expand Down Expand Up @@ -191,7 +190,7 @@ export default class TensorRTLLMExtension extends LocalOAIEngine {
!!info.gpuSetting &&
!!firstGpu &&
info.gpuSetting.gpus.length > 0 &&
this.supportedPlatform.includes(info.osInfo.platform) &&
this.compatibility().platform.includes(info.osInfo.platform) &&
!!firstGpu.arch &&
firstGpu.name.toLowerCase().includes('nvidia') &&
this.supportedGpuArch.includes(firstGpu.arch)
Expand Down
3 changes: 2 additions & 1 deletion extensions/tensorrt-llm-extension/tsconfig.json
Original file line number Diff line number Diff line change
Expand Up @@ -16,5 +16,6 @@
"resolveJsonModule": true,
"typeRoots": ["node_modules/@types"]
},
"include": ["src"]
"include": ["src"],
"exclude": ["**/*.test.ts"]
}
9 changes: 7 additions & 2 deletions web/hooks/useDownloadModel.ts
Original file line number Diff line number Diff line change
Expand Up @@ -9,6 +9,8 @@
ModelArtifact,
DownloadState,
GpuSetting,
ModelFile,
dirName,
} from '@janhq/core'

import { useAtomValue, useSetAtom } from 'jotai'
Expand Down Expand Up @@ -36,8 +38,8 @@

const downloadModel = useCallback(
async (model: Model) => {
const childProgresses: DownloadState[] = model.sources.map(
(source: ModelArtifact) => ({

Check warning on line 42 in web/hooks/useDownloadModel.ts

View workflow job for this annotation

GitHub Actions / coverage-check

41-42 lines are not covered with tests
fileName: source.filename,
modelId: model.id,
time: {
Expand All @@ -55,7 +57,7 @@
)

// set an initial download state
setDownloadState({

Check warning on line 60 in web/hooks/useDownloadModel.ts

View workflow job for this annotation

GitHub Actions / coverage-check

60 line is not covered with tests
fileName: '',
modelId: model.id,
time: {
Expand All @@ -72,9 +74,9 @@
downloadState: 'downloading',
})

addDownloadingModel(model)
const gpuSettings = await getGpuSettings()
await localDownloadModel(

Check warning on line 79 in web/hooks/useDownloadModel.ts

View workflow job for this annotation

GitHub Actions / coverage-check

77-79 lines are not covered with tests
model,
ignoreSSL,
proxyEnabled ? proxy : '',
Expand All @@ -91,10 +93,13 @@
]
)

const abortModelDownload = useCallback(async (model: Model) => {
const abortModelDownload = useCallback(async (model: Model | ModelFile) => {
for (const source of model.sources) {

Check warning on line 97 in web/hooks/useDownloadModel.ts

View workflow job for this annotation

GitHub Actions / coverage-check

97 line is not covered with tests
const path = await joinPath(['models', model.id, source.filename])
const path =
'file_path' in model

Check warning on line 99 in web/hooks/useDownloadModel.ts

View workflow job for this annotation

GitHub Actions / coverage-check

99 line is not covered with tests
? await joinPath([await dirName(model.file_path), source.filename])
: await joinPath(['models', model.id, source.filename])
await abortDownload(path)

Check warning on line 102 in web/hooks/useDownloadModel.ts

View workflow job for this annotation

GitHub Actions / coverage-check

102 line is not covered with tests
}
}, [])

Expand All @@ -110,6 +115,6 @@
proxy: string,
gpuSettings?: GpuSetting
) =>
extensionManager

Check warning on line 118 in web/hooks/useDownloadModel.ts

View workflow job for this annotation

GitHub Actions / coverage-check

118 line is not covered with tests
.get<ModelExtension>(ExtensionTypeEnum.Model)
?.downloadModel(model, gpuSettings, { ignoreSSL, proxy })
Loading