Skip to content

Commit

Permalink
WIP
Browse files Browse the repository at this point in the history
  • Loading branch information
tellet-q committed Jan 2, 2025
1 parent c3a957e commit 80fb4fa
Show file tree
Hide file tree
Showing 6 changed files with 164 additions and 21 deletions.
31 changes: 23 additions & 8 deletions src/qdrant_client/collection.rs
Original file line number Diff line number Diff line change
Expand Up @@ -6,15 +6,9 @@ use tonic::Status;

use crate::auth::TokenInterceptor;
use crate::qdrant::collections_client::CollectionsClient;
use crate::qdrant::{
alias_operations, AliasOperations, ChangeAliases, CollectionClusterInfoRequest,
CollectionClusterInfoResponse, CollectionExistsRequest, CollectionOperationResponse,
CreateAlias, CreateCollection, DeleteAlias, DeleteCollection, GetCollectionInfoRequest,
GetCollectionInfoResponse, ListAliasesRequest, ListAliasesResponse,
ListCollectionAliasesRequest, ListCollectionsRequest, ListCollectionsResponse, RenameAlias,
UpdateCollection, UpdateCollectionClusterSetupRequest, UpdateCollectionClusterSetupResponse,
};
use crate::qdrant::{alias_operations, AliasOperations, ChangeAliases, CollectionClusterInfoRequest, CollectionClusterInfoResponse, CollectionExistsRequest, CollectionOperationResponse, CreateAlias, CreateCollection, DeleteAlias, DeleteCollection, GetCollectionInfoRequest, GetCollectionInfoResponse, ListAliasesRequest, ListAliasesResponse, ListCollectionAliasesRequest, ListCollectionsRequest, ListCollectionsResponse, RenameAlias, UpdateCollection, UpdateCollectionClusterSetupRequest, UpdateCollectionClusterSetupResponse};
use crate::qdrant_client::{Qdrant, QdrantResult};
use crate::qdrant_client::version_check::is_compatible;

/// # Collection operations
///
Expand All @@ -27,6 +21,26 @@ impl Qdrant {
&self,
f: impl Fn(CollectionsClient<InterceptedService<Channel, TokenInterceptor>>) -> O,
) -> QdrantResult<T> {
if self.config.check_compatibility && self.is_compatible() == None {
let client_version = env!("CARGO_PKG_VERSION").to_string();
let server_version = match self.health_check().await {
Ok(info) => info.version,
Err(_) => "Unknown".to_string(),
};
if server_version == "Unknown" {
println!("Failed to obtain server version. \
Unable to check client-server compatibility. \
Set check_compatibility=false to skip version check.");
} else {
let is_compatible = is_compatible(Some(&client_version), Some(&server_version));
self.set_is_compatible(Some(is_compatible));
println!("Client version {client_version} is not compatible with server version {server_version}. \
Major versions should match and minor version difference must not exceed 1. \
Set check_compatibility=false to skip version check.");

}
}

let result = self
.channel
.with_channel(
Expand All @@ -39,6 +53,7 @@ impl Qdrant {
.send_compressed(compression.into())
.accept_compressed(compression.into());
}
// let res = client.get(&HealthCheckRequest {}).await?;
f(client)
},
false,
Expand Down
6 changes: 6 additions & 0 deletions src/qdrant_client/config.rs
Original file line number Diff line number Diff line change
Expand Up @@ -173,6 +173,12 @@ impl QdrantConfig {
pub fn build(self) -> Result<Qdrant, QdrantError> {
Qdrant::new(self)
}

pub fn skip_compatibility_check(mut self) -> Self {
self.check_compatibility = false;
self
}

}

/// Default Qdrant client configuration.
Expand Down
27 changes: 14 additions & 13 deletions src/qdrant_client/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -10,10 +10,10 @@ mod query;
mod search;
mod sharding_keys;
mod snapshot;
// mod version_check;
mod version_check;

use std::cell::{RefCell};
use std::future::Future;
// use tokio::runtime::Runtime;
use tonic::codegen::InterceptedService;
use tonic::transport::{Channel, Uri};
use tonic::Status;
Expand All @@ -22,7 +22,6 @@ use crate::auth::TokenInterceptor;
use crate::channel_pool::ChannelPool;
use crate::qdrant::{qdrant_client, HealthCheckReply, HealthCheckRequest};
use crate::qdrant_client::config::QdrantConfig;
// use crate::qdrant_client::version_check::health_check_sync;
use crate::QdrantError;

/// [`Qdrant`] client result
Expand Down Expand Up @@ -87,6 +86,9 @@ pub struct Qdrant {

/// Internal connection pool
channel: ChannelPool,

/// Internal flag for checking compatibility with the server
is_compatible: RefCell<Option<bool>>,
}

/// # Construct and connect
Expand All @@ -105,20 +107,19 @@ impl Qdrant {
config.keep_alive_while_idle,
);

let client = Self { channel, config };

if client.config.check_compatibility {
let client_version = env!("CARGO_PKG_VERSION").to_string();
// let rt = Runtime::new()?;
// let server_version = rt.block_on(client.health_check())?;
let server_version = tokio::runtime::Handle::current().block_on(client.health_check())?;
println!("Connected to Qdrant version: {}", server_version.version);
println!("Qdrant client version: {}", client_version);
}
let client = Self { channel, config, is_compatible: RefCell::new(None) };

Ok(client)
}

fn set_is_compatible(&self, value: Option<bool>) {
*self.is_compatible.borrow_mut() = value;
}

fn is_compatible(&self) -> Option<bool> {
*self.is_compatible.borrow()
}

/// Build a new Qdrant client with the given URL.
///
/// ```no_run
Expand Down
21 changes: 21 additions & 0 deletions src/qdrant_client/points.rs
Original file line number Diff line number Diff line change
Expand Up @@ -13,6 +13,7 @@ use crate::qdrant::{
UpdateBatchResponse, UpdatePointVectors, UpsertPoints,
};
use crate::qdrant_client::{Qdrant, QdrantResult};
use crate::qdrant_client::version_check::is_compatible;

/// # Point operations
///
Expand All @@ -24,6 +25,26 @@ impl Qdrant {
&self,
f: impl Fn(PointsClient<InterceptedService<Channel, TokenInterceptor>>) -> O,
) -> QdrantResult<T> {
if self.config.check_compatibility && self.is_compatible() == None {
let client_version = env!("CARGO_PKG_VERSION").to_string();
let server_version = match self.health_check().await {
Ok(info) => info.version,
Err(_) => "Unknown".to_string(),
};
if server_version == "Unknown" {
println!("Failed to obtain server version. \
Unable to check client-server compatibility. \
Set check_compatibility=false to skip version check.");
} else {
let is_compatible = is_compatible(Some(&client_version), Some(&server_version));
self.set_is_compatible(Some(is_compatible));
println!("Client version {client_version} is not compatible with server version {server_version}. \
Major versions should match and minor version difference must not exceed 1. \
Set check_compatibility=false to skip version check.");

}
}

let result = self
.channel
.with_channel(
Expand Down
21 changes: 21 additions & 0 deletions src/qdrant_client/snapshot.rs
Original file line number Diff line number Diff line change
Expand Up @@ -12,6 +12,7 @@ use crate::qdrant::{
ListFullSnapshotsRequest, ListSnapshotsRequest, ListSnapshotsResponse,
};
use crate::qdrant_client::{Qdrant, QdrantResult};
use crate::qdrant_client::version_check::is_compatible;

/// # Snapshot operations
///
Expand All @@ -23,6 +24,26 @@ impl Qdrant {
&self,
f: impl Fn(SnapshotsClient<InterceptedService<Channel, TokenInterceptor>>) -> O,
) -> QdrantResult<T> {
if self.config.check_compatibility && self.is_compatible() == None {
let client_version = env!("CARGO_PKG_VERSION").to_string();
let server_version = match self.health_check().await {
Ok(info) => info.version,
Err(_) => "Unknown".to_string(),
};
if server_version == "Unknown" {
println!("Failed to obtain server version. \
Unable to check client-server compatibility. \
Set check_compatibility=false to skip version check.");
} else {
let is_compatible = is_compatible(Some(&client_version), Some(&server_version));
self.set_is_compatible(Some(is_compatible));
println!("Client version {client_version} is not compatible with server version {server_version}. \
Major versions should match and minor version difference must not exceed 1. \
Set check_compatibility=false to skip version check.");

}
}

let result = self
.channel
.with_channel(
Expand Down
79 changes: 79 additions & 0 deletions src/qdrant_client/version_check.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,79 @@
use std::error::Error;
use std::fmt;

#[derive(Debug, Clone)]
pub struct Version {
pub major: u32,
pub minor: u32
}

impl Version {
pub fn parse(version: &str) -> Result<Version, VersionParseError> {
if version.is_empty() {
return Err(VersionParseError::EmptyVersion);
}
let parts: Vec<&str> = version.split('.').collect();
if parts.len() < 2 {
return Err(VersionParseError::InvalidFormat(version.to_string()));
}

let major = parts[0]
.parse::<u32>()
.map_err(|_| VersionParseError::InvalidFormat(version.to_string()))?;
let minor = parts[1]
.parse::<u32>()
.map_err(|_| VersionParseError::InvalidFormat(version.to_string()))?;

Ok(Version { major, minor })
}
}

#[derive(Debug)]
pub enum VersionParseError {
EmptyVersion,
InvalidFormat(String),
}

impl fmt::Display for VersionParseError {
fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
match self {
VersionParseError::EmptyVersion => write!(f, "Version is empty"),
VersionParseError::InvalidFormat(version) => {
write!(f, "Unable to parse version, expected format: x.y[.z], found: {}", version)
}
}
}
}

impl Error for VersionParseError {}

pub fn is_compatible(client_version: Option<&str>, server_version: Option<&str>) -> bool {
if client_version.is_none() || server_version.is_none() {
println!(
"Unable to compare versions, client_version: {:?}, server_version: {:?}",
client_version, server_version
);
return false;
}

let client_version = client_version.unwrap();
let server_version = server_version.unwrap();

if client_version == server_version {
return true;
}

match (Version::parse(client_version), Version::parse(server_version)) {
(Ok(client), Ok(server)) => {
let major_dif = (client.major as i32 - server.major as i32).abs();
if major_dif >= 1 {
return false;
}
(client.minor as i32 - server.minor as i32).abs() <= 1
}
(Err(e), _) | (_, Err(e)) => {
println!("Unable to compare versions: {}", e);
false
}
}
}

0 comments on commit 80fb4fa

Please sign in to comment.