From f69b533a44569a029ace173ee57e99f70361a651 Mon Sep 17 00:00:00 2001 From: AN Long Date: Sun, 24 Aug 2025 22:24:53 +0900 Subject: [PATCH] =?UTF-8?q?=E2=9C=A8=20Allow=20get=20with=20cas=20token=20?= =?UTF-8?q?with=20various=20types?= MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit --- README.md | 5 +++-- src/client.rs | 7 +++---- src/protocol/ascii.rs | 12 +++++++++--- src/value.rs | 30 ++++++++++++++++++++++++++++++ tests/test_ascii.rs | 37 +++++++++++++++++++++++++++++++++++++ tests/tests.rs | 21 +++++++++++++++++++-- 6 files changed, 101 insertions(+), 11 deletions(-) diff --git a/README.md b/README.md index 91bc44c..93820a5 100644 --- a/README.md +++ b/README.md @@ -60,8 +60,9 @@ let value: String = client.get("foo").unwrap().unwrap(); assert_eq!(value, "foobarbaz"); // cas(check and set): -let (_, _, cas_token) = client.get("foo").unwrap().unwrap(); -let cas_id = cas_id.unwrap(); +let (value, _flags, cas_token): (String, u32, Option) = client.get("foo").unwrap().unwrap(); +assert_eq!(value, "foobarbaz"); +let cas_id = cas_token.unwrap(); client.cas("foo", "qux", 0, cas_id).unwrap(); // delete value: diff --git a/src/client.rs b/src/client.rs index e903068..e77ef97 100644 --- a/src/client.rs +++ b/src/client.rs @@ -280,7 +280,7 @@ impl Client { /// ```rust /// let client = memcache::Client::connect("memcache://localhost:12345").unwrap(); /// client.set("foo", "bar", 10).unwrap(); - /// # client.flush().unwrap(); + /// client.flush().unwrap(); /// ``` pub fn set>(&self, key: &str, value: V, expiration: u32) -> Result<(), MemcacheError> { check_key_len(key)?; @@ -293,13 +293,12 @@ impl Client { /// Example: /// /// ```rust - /// use std::collections::HashMap; /// let client = memcache::Client::connect("memcache://localhost:12345").unwrap(); /// client.set("foo", "bar", 10).unwrap(); - /// let (_, _, cas) = client.get("foo").unwrap().unwrap(); + /// let (_, _, cas): (String, u32, Option) = client.get("foo").unwrap().unwrap(); /// let cas = cas.unwrap(); /// assert_eq!(true, client.cas("foo", "bar2", 10, cas).unwrap()); - /// # client.flush().unwrap(); + /// client.flush().unwrap(); /// ``` pub fn cas>( &self, diff --git a/src/protocol/ascii.rs b/src/protocol/ascii.rs index f92d55d..24e94d2 100644 --- a/src/protocol/ascii.rs +++ b/src/protocol/ascii.rs @@ -155,14 +155,20 @@ impl ProtocolTrait for AsciiProtocol { } fn get(&mut self, key: &str) -> Result, MemcacheError> { - write!(self.reader.get_mut(), "get {}\r\n", key)?; + let (command, has_cas) = if V::requires_cas() { + ("gets", true) + } else { + ("get", false) + }; + + write!(self.reader.get_mut(), "{} {}\r\n", command, key)?; - if let Some((k, v)) = self.parse_get_response(false)? { + if let Some((k, v)) = self.parse_get_response(has_cas)? { if k != key { Err(ServerError::BadResponse(Cow::Borrowed( "key doesn't match in the response", )))? - } else if self.parse_get_response::(false)?.is_none() { + } else if self.parse_get_response::(has_cas)?.is_none() { Ok(Some(v)) } else { Err(ServerError::BadResponse(Cow::Borrowed("Expected end of get response")))? diff --git a/src/value.rs b/src/value.rs index 86865c2..4bd6fcf 100644 --- a/src/value.rs +++ b/src/value.rs @@ -122,6 +122,11 @@ pub trait FromMemcacheValue: Sized { pub trait FromMemcacheValueExt: Sized { fn from_memcache_value(value: Vec, flags: u32, cas: Option) -> MemcacheValue; + + /// Returns whether this type requires CAS token + fn requires_cas() -> bool { + false + } } impl FromMemcacheValueExt for V { @@ -134,6 +139,10 @@ impl FromMemcacheValueExt for (Vec, u32, Option) { fn from_memcache_value(value: Vec, flags: u32, cas: Option) -> MemcacheValue { return Ok((value, flags, cas)); } + + fn requires_cas() -> bool { + true + } } impl FromMemcacheValue for (Vec, u32) { @@ -148,6 +157,16 @@ impl FromMemcacheValue for Vec { } } +impl FromMemcacheValueExt for (String, u32, Option) { + fn from_memcache_value(value: Vec, flags: u32, cas: Option) -> MemcacheValue { + return Ok((String::from_utf8(value)?, flags, cas)); + } + + fn requires_cas() -> bool { + true + } +} + impl FromMemcacheValue for (String, u32) { fn from_memcache_value(value: Vec, flags: u32) -> MemcacheValue { return Ok((String::from_utf8(value)?, flags)); @@ -175,6 +194,17 @@ macro_rules! impl_from_memcache_value_for_number { Ok(($ty::from_str(s.as_str())?, flags)) } } + + impl FromMemcacheValueExt for ($ty, u32, Option) { + fn from_memcache_value(value: Vec, flags: u32, cas: Option) -> MemcacheValue { + let s: String = FromMemcacheValue::from_memcache_value(value, 0)?; + Ok(($ty::from_str(s.as_str())?, flags, cas)) + } + + fn requires_cas() -> bool { + true + } + } }; } diff --git a/tests/test_ascii.rs b/tests/test_ascii.rs index 431b66d..9249a32 100644 --- a/tests/test_ascii.rs +++ b/tests/test_ascii.rs @@ -87,6 +87,43 @@ fn test_get_with_flags() { // Test our new FromMemcacheValue implementation let value: Option<(String, u32)> = client.get("test_key").unwrap(); assert_eq!(value, Some(("test_value".to_string(), 114514))); +} + +#[test] +fn test_get_with_cas() { + let client = memcache::connect("memcache://localhost:12345?protocol=ascii").unwrap(); + client.flush().unwrap(); + + // Set a value + client.set("test_key", "test_value", 0).unwrap(); + // Test get with CAS token + let value: Option<(String, u32, Option)> = client.get("test_key").unwrap(); + let (value_str, _flags, cas) = value.unwrap(); + assert_eq!(value_str, "test_value"); + assert!(cas.is_some(), "CAS token should be present"); +} + +#[test] +fn test_cas() { + let client = memcache::Client::connect("memcache://localhost:12345?protocol=ascii").unwrap(); client.flush().unwrap(); + + // Test using get with CAS token for cas operation + client.set("test_cas_key", "initial_value", 0).unwrap(); + let cas_value: Option<(String, u32, Option)> = client.get("test_cas_key").unwrap(); + let (_, _, cas_token) = cas_value.unwrap(); + assert!(cas_token.is_some(), "CAS token should be present from get"); + + // Use the CAS token from get to update the value + assert_eq!( + true, + client + .cas("test_cas_key", "updated_value", 0, cas_token.unwrap()) + .unwrap() + ); + + // Verify the update worked + let updated_value: Option = client.get("test_cas_key").unwrap(); + assert_eq!(updated_value, Some("updated_value".into())); } diff --git a/tests/tests.rs b/tests/tests.rs index e8096c5..a4a43bc 100644 --- a/tests/tests.rs +++ b/tests/tests.rs @@ -247,6 +247,25 @@ fn test_cas() { .cas("not_exists_key", "bar", 0, ascii_foo_value.2.unwrap()) .unwrap() ); + + // Test using get with CAS token for cas operation + client.set("test_cas_key", "initial_value", 0).unwrap(); + let cas_value: Option<(String, u32, Option)> = client.get("test_cas_key").unwrap(); + let (_, _, cas_token) = cas_value.unwrap(); + assert!(cas_token.is_some(), "CAS token should be present from get"); + + // Use the CAS token from get to update the value + assert_eq!( + true, + client + .cas("test_cas_key", "updated_value", 0, cas_token.unwrap()) + .unwrap() + ); + + // Verify the update worked + let updated_value: Option = client.get("test_cas_key").unwrap(); + assert_eq!(updated_value, Some("updated_value".into())); + client.flush().unwrap(); } } @@ -291,6 +310,4 @@ fn test_get_with_flags() { // Test our new FromMemcacheValue implementation let value: Option<(String, u32)> = client.get("test_key").unwrap(); assert_eq!(value, Some(("test_value".to_string(), 114514))); - - client.flush().unwrap(); }