Skip to content

Commit

Permalink
Structured JSON support
Browse files Browse the repository at this point in the history
  • Loading branch information
Stephen D committed Dec 8, 2024
1 parent 1728591 commit 7a903f9
Show file tree
Hide file tree
Showing 7 changed files with 148 additions and 10 deletions.
44 changes: 43 additions & 1 deletion Cargo.lock

Some generated files are not rendered by default. Learn more about how customized files appear on GitHub.

1 change: 1 addition & 0 deletions Cargo.toml
Original file line number Diff line number Diff line change
Expand Up @@ -23,6 +23,7 @@ text-splitter = { version = "0.13.1", optional = true }
regex = { version = "1.9.3", optional = true }
async-stream = "0.3.5"
http = {version = "1.1.0", optional = true }
schemars = "0.8.21"

[features]
default = ["reqwest/default-tls"]
Expand Down
55 changes: 55 additions & 0 deletions examples/structured_output.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,55 @@
use ollama_rs::{
generation::{
completion::request::GenerationRequest,
options::GenerationOptions,
parameters::{schema_for, FormatType, JsonSchema, JsonStructure},
},
Ollama,
};
use serde::Deserialize;

/// Label used for the `temperature` field of [`Output`].
///
/// Derives `JsonSchema` so it can be included in the schema sent with the
/// request, and `Deserialize` so the model's reply can be parsed back into it.
#[derive(JsonSchema, Deserialize, Debug)]
enum Temperature {
Warm,
Cold,
}

/// Shape of the answer requested from the model.
///
/// The JSON schema generated from this struct is attached to the request, and
/// the response text is deserialized into it.
// The fields are never read individually in this example (the value is only
// printed through `Debug`), hence the `dead_code` allowance.
#[allow(dead_code)]
#[derive(JsonSchema, Deserialize, Debug)]
struct Output {
country: String,
capital: String,
languages: Vec<String>,
temperature: Temperature,
}

/// Sends one generation request whose format is the JSON schema of [`Output`],
/// parses the response text into an `Output`, and prints it with `dbg!`.
///
/// Any error from the request or from JSON parsing is propagated to the caller.
#[tokio::main]
async fn main() -> Result<(), Box<dyn std::error::Error>> {
    let client = Ollama::default();

    // Schema-constrained format derived from the `Output` type.
    let structured = FormatType::StructuredJson(JsonStructure::new::<Output>());
    let options = GenerationOptions::default().temperature(0.0);

    let request = GenerationRequest::new(
        String::from("llama3.1:latest"),
        String::from("Tell me about the country north of the USA"),
    )
    .format(structured)
    .options(options);

    let generated = client.generate(request).await?;

    // The response body is expected to be JSON matching the schema of `Output`.
    let resp = serde_json::from_str::<Output>(&generated.response)?;

    // Output {
    //     country: "Canada",
    //     capital: "Ottawa",
    //     languages: [
    //         "English",
    //         "French",
    //     ],
    //     temperature: Cold,
    // }
    dbg!(resp);

    Ok(())
}
6 changes: 3 additions & 3 deletions src/generation/chat/request.rs
Original file line number Diff line number Diff line change
@@ -1,11 +1,11 @@
use serde::{Deserialize, Serialize};
use serde::Serialize;

use crate::generation::{options::GenerationOptions, parameters::FormatType};

use super::ChatMessage;

/// A chat message request to Ollama.
#[derive(Debug, Clone, Serialize, Deserialize)]
#[derive(Debug, Clone, Serialize)]
pub struct ChatMessageRequest {
#[serde(rename = "model")]
pub model_name: String,
Expand Down Expand Up @@ -41,7 +41,7 @@ impl ChatMessageRequest {
self
}

// The format to return a response in. Currently the only accepted value is `json`
/// The format to return a response in.
pub fn format(mut self, format: FormatType) -> Self {
self.format = Some(format);
self
Expand Down
2 changes: 1 addition & 1 deletion src/generation/completion/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -91,7 +91,7 @@ impl GenerationRequest {
self
}

// The format to return a response in. Currently the only accepted value is `json`
/// The format to return a response in.
pub fn format(mut self, format: FormatType) -> Self {
self.format = Some(format);
self
Expand Down
2 changes: 1 addition & 1 deletion src/generation/functions/request.rs
Original file line number Diff line number Diff line change
Expand Up @@ -33,7 +33,7 @@ impl FunctionCallRequest {
self
}

// The format to return a response in. Currently the only accepted value is `json`
/// The format to return a response in.
pub fn format(mut self, format: FormatType) -> Self {
self.chat.format = Some(format);
self
Expand Down
48 changes: 44 additions & 4 deletions src/generation/parameters.rs
Original file line number Diff line number Diff line change
@@ -1,10 +1,50 @@
use serde::{Deserialize, Serialize};
use schemars::{gen::SchemaSettings, schema::RootSchema};
pub use schemars::{schema_for, JsonSchema};
use serde::{Serialize, Serializer};

/// The format to return a response in. Currently the only accepted value is `json`
#[derive(Debug, Serialize, Deserialize, Clone)]
#[serde(rename_all = "lowercase")]
/// The format to return a response in.
///
/// Serialization is implemented by hand: each variant is written as the bare
/// value of the `format` field rather than as a tagged enum.
#[derive(Debug, Clone)]
pub enum FormatType {
/// Plain JSON output; serialized as the string `"json"`.
Json,

/// Structured JSON output described by a schema; serialized as the JSON
/// schema held by the wrapped [`JsonStructure`].
///
/// Requires Ollama 0.5.0 or greater.
StructuredJson(JsonStructure),
}

impl Serialize for FormatType {
fn serialize<S>(&self, serializer: S) -> Result<S::Ok, S::Error>
where
S: Serializer,
{
match self {
FormatType::Json => serializer.serialize_str("json"),
FormatType::StructuredJson(s) => s.schema.serialize(serializer),
}
}
}

/// A JSON schema describing the structure of a response.
///
/// Build one from any type implementing [`JsonSchema`] with
/// [`JsonStructure::new`], then wrap it in [`FormatType::StructuredJson`]:
/// ```rust
/// use ollama_rs::generation::parameters::{FormatType, JsonSchema, JsonStructure};
///
/// #[derive(JsonSchema)]
/// struct Output {
///     country: String,
/// }
///
/// let format = FormatType::StructuredJson(JsonStructure::new::<Output>());
/// ```
#[derive(Debug, Clone)]
pub struct JsonStructure {
// Root schema generated for the target type; this is what gets serialized
// when the enclosing `FormatType::StructuredJson` is serialized.
schema: RootSchema,
}

impl JsonStructure {
    /// Generates the JSON schema for `T` and wraps it in a `JsonStructure`.
    ///
    /// The schema is produced with draft-07 settings and with all subschemas
    /// inlined.
    pub fn new<T: JsonSchema>() -> Self {
        // Ollama doesn't support `$ref` entries (references inside the schema),
        // so subschema inlining has to be switched on explicitly.
        let settings = SchemaSettings::draft07().with(|s| s.inline_subschemas = true);

        Self {
            schema: settings.into_generator().into_root_schema_for::<T>(),
        }
    }
}

/// Used to control how long a model stays loaded in memory, by default models are unloaded after 5 minutes of inactivity
Expand Down

0 comments on commit 7a903f9

Please sign in to comment.