Skip to content

Commit

Permalink
feat add support for generation config
Browse files Browse the repository at this point in the history
  • Loading branch information
alextekartik committed Oct 23, 2024
1 parent dc78a14 commit 4f31abd
Show file tree
Hide file tree
Showing 8 changed files with 163 additions and 6 deletions.
74 changes: 74 additions & 0 deletions vertex_ai/lib/src/generation_config.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,74 @@
import 'package:tekartik_firebase_vertex_ai/vertex_ai.dart';

/// Configuration options for model generation and outputs.
final class GenerationConfig {
/// Number of generated responses to return.
///
/// This value must be between [1, 8], inclusive. If unset, this will default
/// to 1.
final int? candidateCount;

/// The set of character sequences (up to 5) that will stop output generation.
///
/// If specified, the API will stop at the first appearance of a stop
/// sequence. The stop sequence will not be included as part of the response.
final List<String> stopSequences;

/// The maximum number of tokens to include in a candidate.
///
/// If unset, this will default to output_token_limit specified in the `Model`
/// specification.
final int? maxOutputTokens;

/// Controls the randomness of the output.
///
/// Note: The default value varies by model.
///
/// Values can range from `[0.0, infinity]`, inclusive. A value temperature
/// must be greater than 0.0.
final double? temperature;

/// The maximum cumulative probability of tokens to consider when sampling.
///
/// The model uses combined Top-k and nucleus sampling. Tokens are sorted
/// based on their assigned probabilities so that only the most likely tokens
/// are considered. Top-k sampling directly limits the maximum number of
/// tokens to consider, while Nucleus sampling limits number of tokens based
/// on the cumulative probability.
///
/// Note: The default value varies by model.
final double? topP;

/// The maximum number of tokens to consider when sampling.
///
/// The model uses combined Top-k and nucleus sampling. Top-k sampling
/// considers the set of `top_k` most probable tokens. Defaults to 40.
///
/// Note: The default value varies by model.
final int? topK;

/// Output response mimetype of the generated candidate text.
///
/// Supported mimetype:
/// - `text/plain`: (default) Text output.
/// - `application/json`: JSON response in the candidates.
final String? responseMimeType;

/// Output response schema of the generated candidate text.
///
/// - Note: This only applies when the specified ``responseMIMEType`` supports
/// a schema; currently this is limited to `application/json`.
final Schema? responseSchema;

/// Constructor
GenerationConfig({
this.candidateCount,
this.stopSequences = const [],
this.maxOutputTokens,
this.temperature,
this.topP,
this.topK,
this.responseMimeType,
this.responseSchema,
});
}
7 changes: 7 additions & 0 deletions vertex_ai/lib/src/schema.dart
Original file line number Diff line number Diff line change
@@ -0,0 +1,7 @@
import 'package:tekartik_app_json_schema/json_schema.dart';

/// Schema
typedef Schema = JsonSchema;

/// Schema type
typedef SchemaType = JsonSchemaType;
3 changes: 2 additions & 1 deletion vertex_ai/lib/src/vertex_ai.dart
Original file line number Diff line number Diff line change
Expand Up @@ -23,5 +23,6 @@ abstract class FirebaseVertexAi
/// The optional [safetySettings] and [generationConfig] can be used to
/// control and guide the generation. See [SafetySetting] and
/// [GenerationConfig] for details.
VaiGenerativeModel generativeModel({String? model});
VaiGenerativeModel generativeModel(
{String? model, GenerationConfig? generationConfig});
}
2 changes: 2 additions & 0 deletions vertex_ai/lib/vertex_ai.dart
Original file line number Diff line number Diff line change
@@ -1,3 +1,5 @@
export 'src/generation_config.dart' show GenerationConfig;
export 'src/schema.dart' show Schema, SchemaType;
export 'src/vertex_ai.dart' show FirebaseVertexAi, FirebaseVertexAiService;
export 'src/vertex_ai_api.dart' show VaiGenerateContentResponse;
export 'src/vertex_ai_constant.dart';
Expand Down
5 changes: 5 additions & 0 deletions vertex_ai/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,11 @@ dependencies:
url: https://github.com/tekartik/firebase.dart
ref: dart3a
path: firebase
tekartik_app_json_schema:
git:
url: https://github.com/tekartik/app_common_utils.dart
ref: dart3a
path: app_json_schema
# path: ^1.8.0

dev_dependencies:
Expand Down
6 changes: 5 additions & 1 deletion vertex_ai_google/example/simple_prompt.dart
Original file line number Diff line number Diff line change
Expand Up @@ -8,7 +8,11 @@ import 'simple_raw_prompt.dart';
Future<void> main(List<String> args) async {
var apiKey = await getGeminiApiKey();
var vertexAi = FirebaseVertexAiGoogle(apiKey: apiKey);
final model = vertexAi.generativeModel();
final model = vertexAi.generativeModel(
generationConfig: GenerationConfig(
responseMimeType: 'application/json',
responseSchema:
Schema.object(properties: {'total': Schema.number()})));

final prompt = 'Sum 1 and 4';
final content = Content.text(prompt);
Expand Down
68 changes: 64 additions & 4 deletions vertex_ai_google/lib/src/vertex_ai_google.dart
Original file line number Diff line number Diff line change
@@ -1,7 +1,7 @@
import 'package:google_generative_ai/google_generative_ai.dart' as gai;
import 'package:tekartik_common_utils/list_utils.dart';
import 'package:tekartik_firebase/firebase_mixin.dart';
import 'package:tekartik_firebase_local/firebase_local.dart';

import 'package:tekartik_firebase_vertex_ai/vertex_ai.dart';

import 'vertex_ai_model_google.dart';
Expand All @@ -22,6 +22,7 @@ class _FirebaseVertexAiServiceGoogle
final String apiKey;

_FirebaseVertexAiServiceGoogle({required this.apiKey});

@override
FirebaseVertexAiGoogle vertexAi(FirebaseApp app) {
return getInstance(app, () {
Expand All @@ -43,10 +44,13 @@ class _FirebaseVertexAiGoogle
FirebaseApp get app => appLocal;

@override
VaiGenerativeModel generativeModel({String? model}) {
VaiGenerativeModel generativeModel(
{String? model, GenerationConfig? generationConfig}) {
model ??= vertexAiModelGemini1dot5Flash;
var nativeModel =
gai.GenerativeModel(model: model, apiKey: serviceGoogle.apiKey);
var nativeModel = gai.GenerativeModel(
model: model,
apiKey: serviceGoogle.apiKey,
generationConfig: generationConfig?.toGaiGenerationConfig());
return VaiGenerativeModelGoogle(this, nativeModel);
}
}
Expand All @@ -60,3 +64,59 @@ abstract class FirebaseVertexAiGoogle implements FirebaseVertexAi {
return service.vertexAi(appLocal) as FirebaseVertexAiGoogle;
}
}

extension on SchemaType {
gai.SchemaType toGaiSchemaType() {
switch (this) {
case SchemaType.object:
return gai.SchemaType.object;
case SchemaType.array:
return gai.SchemaType.array;
case SchemaType.integer:
return gai.SchemaType.integer;
case SchemaType.boolean:
return gai.SchemaType.boolean;
case SchemaType.string:
return gai.SchemaType.string;
case SchemaType.number:
return gai.SchemaType.number;
}
}
}

extension on Map<String, Object?> {
List<String>? keysWithout(List<String>? without) {
var list = List.of(keys);
if (without != null) {
list.removeWhere((key) => without.contains(key));
}
return list.nonEmpty();
}
}

extension on Schema {
gai.Schema toGaiSchema() {
return gai.Schema(type.toGaiSchemaType(),
items: items?.toGaiSchema(),
format: format,
description: description,
enumValues: enumValues,
nullable: nullable,
properties:
properties?.map((key, value) => MapEntry(key, value.toGaiSchema())),
requiredProperties: properties?.keysWithout(optionalProperties));
}
}

extension on GenerationConfig {
gai.GenerationConfig toGaiGenerationConfig() {
return gai.GenerationConfig(
candidateCount: candidateCount,
maxOutputTokens: maxOutputTokens,
temperature: temperature,
topP: topP,
topK: topK,
responseMimeType: responseMimeType,
responseSchema: responseSchema?.toGaiSchema());
}
}
4 changes: 4 additions & 0 deletions vertex_ai_google/pubspec.yaml
Original file line number Diff line number Diff line change
Expand Up @@ -26,6 +26,10 @@ dependencies:
url: https://github.com/tekartik/firebase.dart
ref: dart3a
path: firebase_local
tekartik_common_utils:
git:
url: https://github.com/tekartik/common_utils.dart
ref: dart3a

dev_dependencies:
lints: ">=5.0.0"
Expand Down

0 comments on commit 4f31abd

Please sign in to comment.