From 80fb4fa88ecec90ae3a72ead76ab33fa43a3777e Mon Sep 17 00:00:00 2001 From: tellet-q Date: Thu, 2 Jan 2025 11:56:10 +0100 Subject: [PATCH] WIP --- src/qdrant_client/collection.rs | 31 +++++++++--- src/qdrant_client/config.rs | 6 +++ src/qdrant_client/mod.rs | 27 +++++----- src/qdrant_client/points.rs | 21 ++++++++ src/qdrant_client/snapshot.rs | 21 ++++++++ src/qdrant_client/version_check.rs | 79 ++++++++++++++++++++++++++++++ 6 files changed, 164 insertions(+), 21 deletions(-) create mode 100644 src/qdrant_client/version_check.rs diff --git a/src/qdrant_client/collection.rs b/src/qdrant_client/collection.rs index 1bb1527c..8c9606a6 100644 --- a/src/qdrant_client/collection.rs +++ b/src/qdrant_client/collection.rs @@ -6,15 +6,9 @@ use tonic::Status; use crate::auth::TokenInterceptor; use crate::qdrant::collections_client::CollectionsClient; -use crate::qdrant::{ - alias_operations, AliasOperations, ChangeAliases, CollectionClusterInfoRequest, - CollectionClusterInfoResponse, CollectionExistsRequest, CollectionOperationResponse, - CreateAlias, CreateCollection, DeleteAlias, DeleteCollection, GetCollectionInfoRequest, - GetCollectionInfoResponse, ListAliasesRequest, ListAliasesResponse, - ListCollectionAliasesRequest, ListCollectionsRequest, ListCollectionsResponse, RenameAlias, - UpdateCollection, UpdateCollectionClusterSetupRequest, UpdateCollectionClusterSetupResponse, -}; +use crate::qdrant::{alias_operations, AliasOperations, ChangeAliases, CollectionClusterInfoRequest, CollectionClusterInfoResponse, CollectionExistsRequest, CollectionOperationResponse, CreateAlias, CreateCollection, DeleteAlias, DeleteCollection, GetCollectionInfoRequest, GetCollectionInfoResponse, ListAliasesRequest, ListAliasesResponse, ListCollectionAliasesRequest, ListCollectionsRequest, ListCollectionsResponse, RenameAlias, UpdateCollection, UpdateCollectionClusterSetupRequest, UpdateCollectionClusterSetupResponse}; use crate::qdrant_client::{Qdrant, QdrantResult}; +use crate::qdrant_client::version_check::is_compatible; /// # Collection operations /// @@ -27,6 +21,26 @@ impl Qdrant { &self, f: impl Fn(CollectionsClient>) -> O, ) -> QdrantResult { + if self.config.check_compatibility && self.is_compatible() == None { + let client_version = env!("CARGO_PKG_VERSION").to_string(); + let server_version = match self.health_check().await { + Ok(info) => info.version, + Err(_) => "Unknown".to_string(), + }; + if server_version == "Unknown" { + println!("Failed to obtain server version. \ + Unable to check client-server compatibility. \ + Set check_compatibility=false to skip version check."); + } else { + let is_compatible = is_compatible(Some(&client_version), Some(&server_version)); + self.set_is_compatible(Some(is_compatible)); + println!("Client version {client_version} is not compatible with server version {server_version}. \ + Major versions should match and minor version difference must not exceed 1. \ + Set check_compatibility=false to skip version check."); + + } + } + let result = self .channel .with_channel( @@ -39,6 +53,7 @@ impl Qdrant { .send_compressed(compression.into()) .accept_compressed(compression.into()); } + // let res = client.get(&HealthCheckRequest {}).await?; f(client) }, false, diff --git a/src/qdrant_client/config.rs b/src/qdrant_client/config.rs index 6cede67e..549125d1 100644 --- a/src/qdrant_client/config.rs +++ b/src/qdrant_client/config.rs @@ -173,6 +173,12 @@ impl QdrantConfig { pub fn build(self) -> Result { Qdrant::new(self) } + + pub fn skip_compatibility_check(mut self) -> Self { + self.check_compatibility = false; + self + } + } /// Default Qdrant client configuration. diff --git a/src/qdrant_client/mod.rs b/src/qdrant_client/mod.rs index 6dd3552a..286bcc20 100644 --- a/src/qdrant_client/mod.rs +++ b/src/qdrant_client/mod.rs @@ -10,10 +10,10 @@ mod query; mod search; mod sharding_keys; mod snapshot; -// mod version_check; +mod version_check; +use std::cell::{RefCell}; use std::future::Future; -// use tokio::runtime::Runtime; use tonic::codegen::InterceptedService; use tonic::transport::{Channel, Uri}; use tonic::Status; @@ -22,7 +22,6 @@ use crate::auth::TokenInterceptor; use crate::channel_pool::ChannelPool; use crate::qdrant::{qdrant_client, HealthCheckReply, HealthCheckRequest}; use crate::qdrant_client::config::QdrantConfig; -// use crate::qdrant_client::version_check::health_check_sync; use crate::QdrantError; /// [`Qdrant`] client result @@ -87,6 +86,9 @@ pub struct Qdrant { /// Internal connection pool channel: ChannelPool, + + /// Internal flag for checking compatibility with the server + is_compatible: RefCell>, } /// # Construct and connect @@ -105,20 +107,19 @@ impl Qdrant { config.keep_alive_while_idle, ); - let client = Self { channel, config }; - - if client.config.check_compatibility { - let client_version = env!("CARGO_PKG_VERSION").to_string(); - // let rt = Runtime::new()?; - // let server_version = rt.block_on(client.health_check())?; - let server_version = tokio::runtime::Handle::current().block_on(client.health_check())?; - println!("Connected to Qdrant version: {}", server_version.version); - println!("Qdrant client version: {}", client_version); - } + let client = Self { channel, config, is_compatible: RefCell::new(None) }; Ok(client) } + fn set_is_compatible(&self, value: Option) { + *self.is_compatible.borrow_mut() = value; + } + + fn is_compatible(&self) -> Option { + *self.is_compatible.borrow() + } + /// Build a new Qdrant client with the given URL. /// /// ```no_run diff --git a/src/qdrant_client/points.rs b/src/qdrant_client/points.rs index 023f84e0..0a10449c 100644 --- a/src/qdrant_client/points.rs +++ b/src/qdrant_client/points.rs @@ -13,6 +13,7 @@ use crate::qdrant::{ UpdateBatchResponse, UpdatePointVectors, UpsertPoints, }; use crate::qdrant_client::{Qdrant, QdrantResult}; +use crate::qdrant_client::version_check::is_compatible; /// # Point operations /// @@ -24,6 +25,26 @@ impl Qdrant { &self, f: impl Fn(PointsClient>) -> O, ) -> QdrantResult { + if self.config.check_compatibility && self.is_compatible() == None { + let client_version = env!("CARGO_PKG_VERSION").to_string(); + let server_version = match self.health_check().await { + Ok(info) => info.version, + Err(_) => "Unknown".to_string(), + }; + if server_version == "Unknown" { + println!("Failed to obtain server version. \ + Unable to check client-server compatibility. \ + Set check_compatibility=false to skip version check."); + } else { + let is_compatible = is_compatible(Some(&client_version), Some(&server_version)); + self.set_is_compatible(Some(is_compatible)); + println!("Client version {client_version} is not compatible with server version {server_version}. \ + Major versions should match and minor version difference must not exceed 1. \ + Set check_compatibility=false to skip version check."); + + } + } + let result = self .channel .with_channel( diff --git a/src/qdrant_client/snapshot.rs b/src/qdrant_client/snapshot.rs index 05b32a86..5067a7f8 100644 --- a/src/qdrant_client/snapshot.rs +++ b/src/qdrant_client/snapshot.rs @@ -12,6 +12,7 @@ use crate::qdrant::{ ListFullSnapshotsRequest, ListSnapshotsRequest, ListSnapshotsResponse, }; use crate::qdrant_client::{Qdrant, QdrantResult}; +use crate::qdrant_client::version_check::is_compatible; /// # Snapshot operations /// @@ -23,6 +24,26 @@ impl Qdrant { &self, f: impl Fn(SnapshotsClient>) -> O, ) -> QdrantResult { + if self.config.check_compatibility && self.is_compatible() == None { + let client_version = env!("CARGO_PKG_VERSION").to_string(); + let server_version = match self.health_check().await { + Ok(info) => info.version, + Err(_) => "Unknown".to_string(), + }; + if server_version == "Unknown" { + println!("Failed to obtain server version. \ + Unable to check client-server compatibility. \ + Set check_compatibility=false to skip version check."); + } else { + let is_compatible = is_compatible(Some(&client_version), Some(&server_version)); + self.set_is_compatible(Some(is_compatible)); + println!("Client version {client_version} is not compatible with server version {server_version}. \ + Major versions should match and minor version difference must not exceed 1. \ + Set check_compatibility=false to skip version check."); + + } + } + let result = self .channel .with_channel( diff --git a/src/qdrant_client/version_check.rs b/src/qdrant_client/version_check.rs new file mode 100644 index 00000000..20a4383e --- /dev/null +++ b/src/qdrant_client/version_check.rs @@ -0,0 +1,79 @@ +use std::error::Error; +use std::fmt; + +#[derive(Debug, Clone)] +pub struct Version { + pub major: u32, + pub minor: u32 +} + +impl Version { + pub fn parse(version: &str) -> Result { + if version.is_empty() { + return Err(VersionParseError::EmptyVersion); + } + let parts: Vec<&str> = version.split('.').collect(); + if parts.len() < 2 { + return Err(VersionParseError::InvalidFormat(version.to_string())); + } + + let major = parts[0] + .parse::() + .map_err(|_| VersionParseError::InvalidFormat(version.to_string()))?; + let minor = parts[1] + .parse::() + .map_err(|_| VersionParseError::InvalidFormat(version.to_string()))?; + + Ok(Version { major, minor }) + } +} + +#[derive(Debug)] +pub enum VersionParseError { + EmptyVersion, + InvalidFormat(String), +} + +impl fmt::Display for VersionParseError { + fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result { + match self { + VersionParseError::EmptyVersion => write!(f, "Version is empty"), + VersionParseError::InvalidFormat(version) => { + write!(f, "Unable to parse version, expected format: x.y[.z], found: {}", version) + } + } + } +} + +impl Error for VersionParseError {} + +pub fn is_compatible(client_version: Option<&str>, server_version: Option<&str>) -> bool { + if client_version.is_none() || server_version.is_none() { + println!( + "Unable to compare versions, client_version: {:?}, server_version: {:?}", + client_version, server_version + ); + return false; + } + + let client_version = client_version.unwrap(); + let server_version = server_version.unwrap(); + + if client_version == server_version { + return true; + } + + match (Version::parse(client_version), Version::parse(server_version)) { + (Ok(client), Ok(server)) => { + let major_dif = (client.major as i32 - server.major as i32).abs(); + if major_dif >= 1 { + return false; + } + (client.minor as i32 - server.minor as i32).abs() <= 1 + } + (Err(e), _) | (_, Err(e)) => { + println!("Unable to compare versions: {}", e); + false + } + } +}