-
Notifications
You must be signed in to change notification settings - Fork 0
/
Copy pathmodelPricing.ts
148 lines (128 loc) · 4.35 KB
/
modelPricing.ts
1
2
3
4
5
6
7
8
9
10
11
12
13
14
15
16
17
18
19
20
21
22
23
24
25
26
27
28
29
30
31
32
33
34
35
36
37
38
39
40
41
42
43
44
45
46
47
48
49
50
51
52
53
54
55
56
57
58
59
60
61
62
63
64
65
66
67
68
69
70
71
72
73
74
75
76
77
78
79
80
81
82
83
84
85
86
87
88
89
90
91
92
93
94
95
96
97
98
99
100
101
102
103
104
105
106
107
108
109
110
111
112
113
114
115
116
117
118
119
120
121
122
123
124
125
126
127
128
129
130
131
132
133
134
135
136
137
138
139
140
141
142
143
144
145
146
147
148
/*
* This file defines the pricing for the different model operations.
*/
// TRANSFORMS AND MODEL OPTIONS -------------------------------------------------
/**
 * Transform (operation) kinds that a request can ask for.
 *
 * Each member's string value is identical to its name. The request parameter
 * types in this file use these members as the literal type of their
 * `transform` field, which makes that field the discriminant of the request
 * union accepted by `getModelPricing`.
 */
export enum Transforms {
EMBEDDING = "EMBEDDING",
FILL_MASK = "FILL_MASK",
DIFFUSION_GENERATE = "DIFFUSION_GENERATE",
PROMOTER_ACTIVITY = "PROMOTER_ACTIVITY",
TRACKS_PREDICTION = "TRACKS_PREDICTION",
}
/**
 * Identifiers of the models a request can select.
 *
 * Member names are identifier-safe; for several members the string value
 * differs from the member name by using hyphens instead of underscores
 * (e.g. `esm2_650M` has the value "esm2-650M"), so always compare against the
 * enum member rather than a hand-written string.
 */
export enum ModelOptions {
abdiffusion = "abdiffusion",
borzoi_human_fold0 = "borzoi_human_fold0",
lcdna = "lcdna",
ginkgo_aa0_650M = "ginkgo-aa0-650M",
esm2_650M = "esm2-650M",
esm2_3B = "esm2-3B",
ginkgo_maskedlm_3utr_v1 = "ginkgo-maskedlm-3utr-v1",
}
// REQUEST TYPES ----------------------------------------------------------------
/**
 * Parameters of a request that uses the `EMBEDDING` transform.
 *
 * `getModelPricing` prices this request per token counted in `sequence`.
 */
export type MeanEmbeddingParams = {
// Discriminant: pins this shape to the EMBEDDING transform.
transform: Transforms.EMBEDDING;
// Input sequence text; its token count drives the price.
sequence: string;
// Model to run; only these four options are accepted for this transform.
model:
| ModelOptions.esm2_650M
| ModelOptions.esm2_3B
| ModelOptions.ginkgo_maskedlm_3utr_v1
| ModelOptions.ginkgo_aa0_650M;
};
/**
 * Parameters of a request that uses the `FILL_MASK` transform.
 *
 * Accepts the same models as `MeanEmbeddingParams`, and `getModelPricing`
 * prices it the same way: per token counted in `sequence`.
 */
export type MaskedInferenceParams = {
// Discriminant: pins this shape to the FILL_MASK transform.
transform: Transforms.FILL_MASK;
// Input sequence text; its token count drives the price.
sequence: string;
// Model to run; only these four options are accepted for this transform.
model:
| ModelOptions.esm2_650M
| ModelOptions.esm2_3B
| ModelOptions.ginkgo_maskedlm_3utr_v1
| ModelOptions.ginkgo_aa0_650M;
};
/**
 * Parameters of a request that uses the `PROMOTER_ACTIVITY` transform.
 *
 * `getModelPricing` charges a fixed price per request for this transform; none
 * of the sequence fields below influence the price.
 */
export type PromoterActivityParams = {
// Discriminant: pins this shape to the PROMOTER_ACTIVITY transform.
transform: Transforms.PROMOTER_ACTIVITY;
// Promoter sequence text.
promoter_sequence: string;
// ORF sequence text.
orf_sequence: string;
// Map from string keys to lists of strings.
// NOTE(review): the meaning of the keys and values is not visible in this
// file - confirm against the API that consumes it.
tissue_of_interest: Record<string, string[]>;
// The only model accepted for this transform.
model: ModelOptions.borzoi_human_fold0;
};
/**
 * Parameters of a request that uses the `TRACKS_PREDICTION` transform.
 *
 * `getModelPricing` charges a base price plus a per-track amount, so only the
 * length of `tracks` influences the price.
 */
export type TracksPredictionParams = {
// Discriminant: pins this shape to the TRACKS_PREDICTION transform.
transform: Transforms.TRACKS_PREDICTION;
// Input sequence text (not used for pricing).
sequence: string;
// Requested tracks; the number of entries drives the price.
tracks: string[];
// The only model accepted for this transform.
model: ModelOptions.borzoi_human_fold0;
};
/**
 * Parameters of a request that uses the `DIFFUSION_GENERATE` transform.
 *
 * `getModelPricing` bills this request per model pass, where the number of
 * passes is the count of `<mask>` tokens in `sequence` divided by
 * `unmaskings_per_step`.
 */
export type DiffusionGenerateParams = {
// Discriminant: pins this shape to the DIFFUSION_GENERATE transform.
transform: Transforms.DIFFUSION_GENERATE;
// Divisor applied to the masked-token count when computing billed passes.
unmaskings_per_step: number;
// Input sequence text; its `<mask>` occurrences are counted for pricing.
sequence: string;
// Model to run; only these two options are accepted for this transform.
model: ModelOptions.abdiffusion | ModelOptions.lcdna;
};
// HELPER FUNCTIONS -------------------------------------------------------------
/**
 * Returns how many tokens `sequence` contains.
 *
 * Matching is case-insensitive (the input is lower-cased first). A token is
 * either:
 *  - one angle-bracketed span such as `<mask>` (a `<`, at least one character
 *    that is not `>`, then `>`), counted as a single token; or
 *  - one letter from the set `acdefghiklmnpqrstvwy` (presumably the twenty
 *    standard amino-acid one-letter codes).
 * Every other character is ignored.
 *
 * @param sequence - The sequence to count tokens from.
 * @returns The number of tokens found, or 0 when there are none.
 */
const getNumberOfTokens = (sequence: string): number => {
  const lowered = sequence.toLowerCase();
  const matches = lowered.match(/(<[^>]+>|[acdefghiklmnpqrstvwy])/g);
  // A global `match` yields null (not an empty array) when nothing matches.
  if (matches === null) {
    return 0;
  }
  return matches.length;
};
/**
 * Returns how many `<mask>` tokens occur in `sequence`.
 *
 * The input is lower-cased before matching, so `<MASK>` and `<Mask>` count
 * as well.
 *
 * @param sequence - The sequence to count masked tokens from.
 * @returns The number of `<mask>` occurrences, or 0 when there are none.
 */
const getNumberOfMaskedTokens = (sequence: string): number => {
  // A global `match` yields null when nothing matches; treat that as no hits.
  const hits = sequence.toLowerCase().match(/<mask>/g) ?? [];
  return hits.length;
};
// PRICING FUNCTION --------------------------------------------------------------
/** Price per token for the models billed by token (EMBEDDING and FILL_MASK). */
const TOKEN_COST_PER_MODEL: Readonly<
  Record<MeanEmbeddingParams["model"], number>
> = {
  [ModelOptions.esm2_650M]: 0.00000018,
  [ModelOptions.esm2_3B]: 0.00000025,
  [ModelOptions.ginkgo_maskedlm_3utr_v1]: 0.00000018,
  [ModelOptions.ginkgo_aa0_650M]: 0.00000018,
};

/** Price per model pass for the models billed by pass. */
const COST_PER_MODEL_PASS: Readonly<
  Record<
    ModelOptions.borzoi_human_fold0 | DiffusionGenerateParams["model"],
    number
  >
> = {
  [ModelOptions.borzoi_human_fold0]: 0.0025,
  [ModelOptions.abdiffusion]: 0.0002,
  [ModelOptions.lcdna]: 0.01,
};

/** Base price of a TRACKS_PREDICTION request ($0.003). */
const TRACKS_PREDICTION_BASE_COST = 0.003;

/** Additional price per requested track in a TRACKS_PREDICTION request ($0.00003). */
const TRACKS_PREDICTION_COST_PER_TRACK = 0.00003;

/**
 * Calculates the pricing for a given model operation based on the provided parameters.
 *
 * Pricing rules by transform:
 *  - EMBEDDING / FILL_MASK: per-token price of the chosen model multiplied by
 *    the number of tokens in `sequence`.
 *  - PROMOTER_ACTIVITY: fixed price per request.
 *  - TRACKS_PREDICTION: $0.003 + $0.00003 * `tracks.length`.
 *  - DIFFUSION_GENERATE: per-pass price of the chosen model multiplied by
 *    (number of `<mask>` tokens in `sequence`) / `unmaskings_per_step`.
 *
 * @param params - The parameters for the model operation, including the sequence, model,
 * and transform type.
 * @returns The calculated pricing for the model operation.
 * @throws RangeError if `unmaskings_per_step` is not a positive number for a
 * DIFFUSION_GENERATE request (the division would otherwise produce a NaN,
 * infinite or negative price).
 * @throws Error if `params.transform` is not one of the supported transforms
 * (only reachable from callers that bypass the type checker).
 */
export function getModelPricing(
  params:
    | MeanEmbeddingParams
    | MaskedInferenceParams
    | PromoterActivityParams
    | TracksPredictionParams
    | DiffusionGenerateParams
): number {
  switch (params.transform) {
    // Both token-priced transforms share the same formula.
    case Transforms.EMBEDDING:
    case Transforms.FILL_MASK:
      return (
        TOKEN_COST_PER_MODEL[params.model] * getNumberOfTokens(params.sequence)
      );
    case Transforms.PROMOTER_ACTIVITY:
      return COST_PER_MODEL_PASS[ModelOptions.borzoi_human_fold0]; // Fixed price per request
    case Transforms.TRACKS_PREDICTION:
      return (
        TRACKS_PREDICTION_BASE_COST +
        TRACKS_PREDICTION_COST_PER_TRACK * params.tracks.length
      );
    case Transforms.DIFFUSION_GENERATE: {
      const unmaskingsPerStep = params.unmaskings_per_step;
      // `!(x > 0)` rejects 0, negatives and NaN in one check.
      if (!(unmaskingsPerStep > 0)) {
        throw new RangeError(
          `unmaskings_per_step must be a positive number, got ${unmaskingsPerStep}`
        );
      }
      // NOTE(review): the pass count is intentionally left fractional to keep
      // existing prices unchanged - confirm whether it should be rounded up.
      const passCount =
        getNumberOfMaskedTokens(params.sequence) / unmaskingsPerStep;
      return COST_PER_MODEL_PASS[params.model] * passCount;
    }
    default: {
      // Compile-time exhaustiveness check: fails to type-check if a new
      // transform is added to the union without a pricing rule above.
      const unhandled: never = params;
      throw new Error(
        `Unsupported transform: ${String(
          (unhandled as { transform?: unknown }).transform
        )}`
      );
    }
  }
}