diff --git a/Cargo.lock b/Cargo.lock
index 047e704b..5d40db12 100644
--- a/Cargo.lock
+++ b/Cargo.lock
@@ -1485,7 +1485,11 @@
 source = "registry+https://github.com/rust-lang/crates.io-index"
 checksum = "ba5a308b75df32fe02788e748662718f03fde005016435c444eea572398219fd"
 dependencies = [
  "bytes",
+ "futures-core",
  "memchr",
+ "pin-project-lite",
+ "tokio",
+ "tokio-util",
 ]
 
 [[package]]
@@ -6538,6 +6542,30 @@ dependencies = [
  "yasna",
 ]
 
+[[package]]
+name = "redis"
+version = "0.25.4"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "e0d7a6955c7511f60f3ba9e86c6d02b3c3f144f8c24b288d1f4e18074ab8bbec"
+dependencies = [
+ "arc-swap",
+ "async-trait",
+ "bytes",
+ "combine",
+ "futures",
+ "futures-util",
+ "itoa",
+ "percent-encoding",
+ "pin-project-lite",
+ "ryu",
+ "sha1_smol",
+ "socket2 0.5.10",
+ "tokio",
+ "tokio-retry",
+ "tokio-util",
+ "url",
+]
+
 [[package]]
 name = "redox_syscall"
 version = "0.5.4"
@@ -7501,6 +7529,12 @@ dependencies = [
  "digest",
 ]
 
+[[package]]
+name = "sha1_smol"
+version = "1.0.1"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "bbfa15b3dddfee50a0fff136974b3e1bde555604ba463834a7eb7deb6417705d"
+
 [[package]]
 name = "sha2"
 version = "0.10.8"
@@ -8599,6 +8633,17 @@ dependencies = [
  "syn 2.0.106",
 ]
 
+[[package]]
+name = "tokio-retry"
+version = "0.3.0"
+source = "registry+https://github.com/rust-lang/crates.io-index"
+checksum = "7f57eb36ecbe0fc510036adff84824dd3c24bb781e21bfa67b69d556aa85214f"
+dependencies = [
+ "pin-project",
+ "rand 0.8.5",
+ "tokio",
+]
+
 [[package]]
 name = "tokio-rustls"
 version = "0.23.4"
@@ -8933,10 +8978,14 @@ name = "torii-cache"
 version = "1.8.10"
 dependencies = [
  "async-trait",
+ "bincode",
  "dashmap 6.1.0",
  "dojo-types",
  "dojo-world",
+ "redis",
+ "serde",
  "serde_json",
+ "sha2",
  "sqlx",
  "starknet",
  "thiserror 1.0.63",
@@ -9482,6 +9531,7 @@ dependencies = [
  "anyhow",
  "async-trait",
  "base64 0.21.7",
+ "bincode",
  "bitflags 2.9.1",
  "cainome",
  "chrono",
diff --git a/Cargo.toml b/Cargo.toml
index c6fd739c..9f8e54f2 100644
--- a/Cargo.toml
+++ b/Cargo.toml
@@ -179,6 +179,9 @@ starknet-core = "0.16.0"
 dashmap = "6.1.0"
 
+redis = { version = "0.25", features = ["tokio-comp", "connection-manager"] }
+sha2 = "0.10"
+bincode = "1.3"
 
 # [patch.crates-io]
 # cainome = { git = "https://github.com/Larkooo/cainome", branch = "patch-1" }
diff --git a/crates/cache/Cargo.toml b/crates/cache/Cargo.toml
index 3e30ab25..92e93e73 100644
--- a/crates/cache/Cargo.toml
+++ b/crates/cache/Cargo.toml
@@ -4,17 +4,25 @@ edition.workspace = true
 repository.workspace = true
 version.workspace = true
 
+[features]
+default = []
+redis = ["dep:redis"]
+
 [dependencies]
 async-trait.workspace = true
+bincode.workspace = true
 dashmap.workspace = true
 dojo-types.workspace = true
 dojo-world.workspace = true
+redis = { workspace = true, optional = true }
+serde.workspace = true
+serde_json.workspace = true
+sha2.workspace = true
 sqlx.workspace = true
 starknet.workspace = true
 thiserror.workspace = true
 tokio.workspace = true
 torii-math.workspace = true
+torii-proto.workspace = true
 torii-sqlite-types.workspace = true
-serde_json.workspace = true
 torii-storage.workspace = true
-torii-proto.workspace = true
diff --git a/crates/cache/src/lib.rs b/crates/cache/src/lib.rs
index 5dd4c36c..9a331f22 100644
--- a/crates/cache/src/lib.rs
+++ b/crates/cache/src/lib.rs
@@ -18,6 +18,7 @@ use torii_storage::ReadOnlyStorage;
 use crate::error::Error;
 
 pub mod error;
+pub mod query_cache;
 
 pub type CacheError = Error;
diff --git a/crates/cache/src/query_cache/cached_row.rs b/crates/cache/src/query_cache/cached_row.rs
new file mode 100644
index 00000000..e2bc46f6
--- /dev/null
+++ b/crates/cache/src/query_cache/cached_row.rs
@@ -0,0 +1,177 @@
+use serde::{Deserialize, Serialize};
+use sqlx::sqlite::SqliteRow;
+use sqlx::{Column, Row, TypeInfo};
+
+/// A serializable representation of a SQLite row.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct CachedRow {
+    /// Column metadata.
+    pub columns: Vec<CachedColumn>,
+    /// Row values.
+    pub values: Vec<CachedValue>,
+}
+
+/// Column metadata for cached rows.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct CachedColumn {
+    /// Column name.
+    pub name: String,
+    /// SQLite type name.
+    pub type_name: String,
+}
+
+/// Cached value types matching SQLite's type affinity.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub enum CachedValue {
+    /// NULL value.
+    Null,
+    /// INTEGER value (i64).
+    Integer(i64),
+    /// REAL value (f64).
+    Real(f64),
+    /// TEXT value.
+    Text(String),
+    /// BLOB value.
+    Blob(Vec<u8>),
+}
+
+impl CachedRow {
+    /// Create a CachedRow from a SQLite row.
+    pub fn from_sqlite_row(row: &SqliteRow) -> Self {
+        let columns: Vec<CachedColumn> = row
+            .columns()
+            .iter()
+            .map(|c| CachedColumn {
+                name: c.name().to_string(),
+                type_name: c.type_info().name().to_string(),
+            })
+            .collect();
+
+        let values: Vec<CachedValue> = (0..columns.len())
+            .map(|i| {
+                // Try each type in order based on SQLite type affinity
+                if let Ok(v) = row.try_get::<Option<i64>, _>(i) {
+                    match v {
+                        Some(n) => CachedValue::Integer(n),
+                        None => CachedValue::Null,
+                    }
+                } else if let Ok(v) = row.try_get::<Option<f64>, _>(i) {
+                    match v {
+                        Some(n) => CachedValue::Real(n),
+                        None => CachedValue::Null,
+                    }
+                } else if let Ok(v) = row.try_get::<Option<String>, _>(i) {
+                    match v {
+                        Some(s) => CachedValue::Text(s),
+                        None => CachedValue::Null,
+                    }
+                } else if let Ok(v) = row.try_get::<Option<Vec<u8>>, _>(i) {
+                    match v {
+                        Some(b) => CachedValue::Blob(b),
+                        None => CachedValue::Null,
+                    }
+                } else {
+                    CachedValue::Null
+                }
+            })
+            .collect();
+
+        Self { columns, values }
+    }
+
+    /// Get a value by column name.
+    pub fn get(&self, column_name: &str) -> Option<&CachedValue> {
+        let idx = self.columns.iter().position(|c| c.name == column_name)?;
+        self.values.get(idx)
+    }
+
+    /// Get an integer value by column name.
+    pub fn get_i64(&self, column_name: &str) -> Option<i64> {
+        match self.get(column_name)? {
+            CachedValue::Integer(v) => Some(*v),
+            _ => None,
+        }
+    }
+
+    /// Get a string value by column name.
+    pub fn get_string(&self, column_name: &str) -> Option<&str> {
+        match self.get(column_name)? {
+            CachedValue::Text(v) => Some(v),
+            _ => None,
+        }
+    }
+
+    /// Get a blob value by column name.
+    pub fn get_blob(&self, column_name: &str) -> Option<&[u8]> {
+        match self.get(column_name)? {
+            CachedValue::Blob(v) => Some(v),
+            _ => None,
+        }
+    }
+}
+
+impl CachedValue {
+    /// Check if the value is null.
+    pub fn is_null(&self) -> bool {
+        matches!(self, CachedValue::Null)
+    }
+
+    /// Convert to i64 if possible.
+    pub fn as_i64(&self) -> Option<i64> {
+        match self {
+            CachedValue::Integer(v) => Some(*v),
+            _ => None,
+        }
+    }
+
+    /// Convert to f64 if possible.
+    pub fn as_f64(&self) -> Option<f64> {
+        match self {
+            CachedValue::Real(v) => Some(*v),
+            CachedValue::Integer(v) => Some(*v as f64),
+            _ => None,
+        }
+    }
+
+    /// Convert to string if possible.
+    pub fn as_str(&self) -> Option<&str> {
+        match self {
+            CachedValue::Text(v) => Some(v),
+            _ => None,
+        }
+    }
+
+    /// Convert to bytes if possible.
+    pub fn as_bytes(&self) -> Option<&[u8]> {
+        match self {
+            CachedValue::Blob(v) => Some(v),
+            _ => None,
+        }
+    }
+}
+
+/// A cached page of query results with optional cursor for pagination.
+#[derive(Debug, Clone, Serialize, Deserialize)]
+pub struct CachedPage {
+    /// The rows in this page.
+    pub rows: Vec<CachedRow>,
+    /// Optional cursor for the next page.
+    pub next_cursor: Option<String>,
+}
+
+impl CachedPage {
+    /// Create a new cached page.
+    pub fn new(rows: Vec<CachedRow>, next_cursor: Option<String>) -> Self {
+        Self { rows, next_cursor }
+    }
+
+    /// Check if this page is empty.
+    pub fn is_empty(&self) -> bool {
+        self.rows.is_empty()
+    }
+
+    /// Get the number of rows in this page.
+    pub fn len(&self) -> usize {
+        self.rows.len()
+    }
+}
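Note: `CachedRow`/`CachedValue` is the payload that actually crosses the cache boundary, serialized with bincode (added to the workspace above). A minimal round-trip sketch of the typed accessors; the row contents here are illustrative only:

```rust
use torii_cache::query_cache::{CachedColumn, CachedRow, CachedValue};

fn main() {
    // Mirrors what from_sqlite_row would produce for `SELECT id, name FROM entities`.
    let row = CachedRow {
        columns: vec![
            CachedColumn { name: "id".into(), type_name: "INTEGER".into() },
            CachedColumn { name: "name".into(), type_name: "TEXT".into() },
        ],
        values: vec![CachedValue::Integer(1), CachedValue::Text("Beast".into())],
    };

    // The same bincode round-trip CachingPool performs when storing/loading results.
    let bytes = bincode::serialize(&row).expect("serialize");
    let back: CachedRow = bincode::deserialize(&bytes).expect("deserialize");

    // Typed accessors return None on a missing column or a type mismatch.
    assert_eq!(back.get_i64("id"), Some(1));
    assert_eq!(back.get_string("name"), Some("Beast"));
    assert!(back.get("missing").is_none());
}
```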
diff --git a/crates/cache/src/query_cache/in_memory.rs b/crates/cache/src/query_cache/in_memory.rs
new file mode 100644
index 00000000..12768d56
--- /dev/null
+++ b/crates/cache/src/query_cache/in_memory.rs
@@ -0,0 +1,180 @@
+use async_trait::async_trait;
+use dashmap::DashMap;
+use std::time::{Duration, Instant};
+
+use super::{CacheResult, QueryCache, QueryCacheConfig, QueryCacheError};
+
+/// A cached entry with expiration time.
+#[derive(Debug)]
+struct CacheEntry {
+    data: Vec<u8>,
+    expires_at: Instant,
+}
+
+/// In-memory query cache implementation using DashMap for thread-safety.
+#[derive(Debug)]
+pub struct InMemoryQueryCache {
+    config: QueryCacheConfig,
+    entries: DashMap<String, CacheEntry>,
+}
+
+impl InMemoryQueryCache {
+    /// Create a new in-memory cache with the given configuration.
+    pub fn new(config: QueryCacheConfig) -> Self {
+        Self { config, entries: DashMap::new() }
+    }
+
+    /// Get the number of entries in the cache.
+    pub fn len(&self) -> usize {
+        self.entries.len()
+    }
+
+    /// Check if the cache is empty.
+    pub fn is_empty(&self) -> bool {
+        self.entries.is_empty()
+    }
+
+    /// Clear all entries from the cache.
+    pub fn clear(&self) {
+        self.entries.clear();
+    }
+
+    /// Remove expired entries from the cache.
+    pub fn evict_expired(&self) {
+        let now = Instant::now();
+        self.entries.retain(|_, entry| entry.expires_at > now);
+    }
+}
+
+#[async_trait]
+impl QueryCache for InMemoryQueryCache {
+    async fn get(&self, key: &str) -> CacheResult<Vec<u8>> {
+        if !self.config.enabled {
+            return CacheResult::Disabled;
+        }
+
+        if let Some(entry) = self.entries.get(key) {
+            if entry.expires_at > Instant::now() {
+                return CacheResult::Hit(entry.data.clone());
+            }
+            // Entry expired, remove it
+            drop(entry);
+            self.entries.remove(key);
+        }
+        CacheResult::Miss
+    }
+
+    async fn set(&self, key: &str, value: &[u8]) -> Result<(), QueryCacheError> {
+        if !self.config.enabled {
+            return Ok(());
+        }
+
+        self.entries.insert(
+            key.to_string(),
+            CacheEntry {
+                data: value.to_vec(),
+                expires_at: Instant::now() + Duration::from_secs(self.config.ttl_seconds),
+            },
+        );
+        Ok(())
+    }
+
+    async fn invalidate(&self, key: &str) -> Result<(), QueryCacheError> {
+        self.entries.remove(key);
+        Ok(())
+    }
+
+    async fn invalidate_pattern(&self, pattern: &str) -> Result<(), QueryCacheError> {
+        // Pattern is expected to end with "*" for prefix matching
+        let prefix = pattern.trim_end_matches('*');
+        self.entries.retain(|k, _| !k.starts_with(prefix));
+        Ok(())
+    }
+
+    fn is_available(&self) -> bool {
+        true
+    }
+
+    fn config(&self) -> &QueryCacheConfig {
+        &self.config
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[tokio::test]
+    async fn test_cache_disabled() {
+        let config = QueryCacheConfig { enabled: false, ttl_seconds: 60 };
+        let cache = InMemoryQueryCache::new(config);
+
+        // Set should succeed but not actually store
+        cache.set("key", b"value").await.unwrap();
+        assert!(cache.is_empty());
+
+        // Get should return Disabled
+        let result = cache.get("key").await;
+        assert!(matches!(result, CacheResult::Disabled));
+    }
+
+    #[tokio::test]
+    async fn test_cache_hit_miss() {
+        let config = QueryCacheConfig { enabled: true, ttl_seconds: 60 };
+        let cache = InMemoryQueryCache::new(config);
+
+        // Miss on empty cache
+        let result = cache.get("key").await;
+        assert!(matches!(result, CacheResult::Miss));
+
+        // Set value
+        cache.set("key", b"value").await.unwrap();
+
+        // Hit
+        let result = cache.get("key").await;
+        match result {
+            CacheResult::Hit(data) => assert_eq!(data, b"value"),
+            _ => panic!("Expected hit"),
+        }
+    }
+
+    #[tokio::test]
+    async fn test_cache_invalidate() {
+        let config = QueryCacheConfig { enabled: true, ttl_seconds: 60 };
+        let cache = InMemoryQueryCache::new(config);
+
+        cache.set("key1", b"value1").await.unwrap();
+        cache.set("key2", b"value2").await.unwrap();
+
+        // Invalidate single key
+        cache.invalidate("key1").await.unwrap();
+
+        let result = cache.get("key1").await;
+        assert!(matches!(result, CacheResult::Miss));
+
+        let result = cache.get("key2").await;
+        assert!(matches!(result, CacheResult::Hit(_)));
+    }
+
+    #[tokio::test]
+    async fn test_cache_invalidate_pattern() {
+        let config = QueryCacheConfig { enabled: true, ttl_seconds: 60 };
+        let cache = InMemoryQueryCache::new(config);
+
+        cache.set("torii:query:entities:abc", b"value1").await.unwrap();
+        cache.set("torii:query:entities:def", b"value2").await.unwrap();
+        cache.set("torii:query:tokens:xyz", b"value3").await.unwrap();
+
+        // Invalidate entities pattern
+        cache.invalidate_pattern("torii:query:entities:*").await.unwrap();
+
+        let result = cache.get("torii:query:entities:abc").await;
+        assert!(matches!(result, CacheResult::Miss));
+
+        let result = cache.get("torii:query:entities:def").await;
+        assert!(matches!(result, CacheResult::Miss));
+
+        let result = cache.get("torii:query:tokens:xyz").await;
+        assert!(matches!(result, CacheResult::Hit(_)));
+    }
+}
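Note: `evict_expired` is never scheduled by this diff; expired entries are otherwise only dropped lazily when `get` happens to touch them, so a rarely-read key keeps its bytes resident until then. One possible wiring, assuming a tokio runtime and a shared handle (the 30-second period is arbitrary):

```rust
use std::sync::Arc;
use std::time::Duration;
use torii_cache::query_cache::{InMemoryQueryCache, QueryCacheConfig};

fn spawn_eviction_task(cache: Arc<InMemoryQueryCache>) -> tokio::task::JoinHandle<()> {
    tokio::spawn(async move {
        let mut interval = tokio::time::interval(Duration::from_secs(30));
        loop {
            interval.tick().await;
            // Drops every entry whose expires_at is already in the past.
            cache.evict_expired();
        }
    })
}

#[tokio::main]
async fn main() {
    let cache = Arc::new(InMemoryQueryCache::new(QueryCacheConfig {
        enabled: true,
        ttl_seconds: 60,
    }));
    let _eviction = spawn_eviction_task(cache.clone());
    // ... hand `cache` to the rest of the application here ...
}
```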
diff --git a/crates/cache/src/query_cache/key.rs b/crates/cache/src/query_cache/key.rs
new file mode 100644
index 00000000..fcab9898
--- /dev/null
+++ b/crates/cache/src/query_cache/key.rs
@@ -0,0 +1,102 @@
+use sha2::{Digest, Sha256};
+
+/// Generate a cache key from SQL query and bind parameters.
+///
+/// The key format is: `torii:query:{table}:{hash}`
+/// where hash is a SHA-256 hash of the SQL and bind parameters.
+pub fn generate_cache_key(sql: &str, binds: &[String]) -> String {
+    let mut hasher = Sha256::new();
+    hasher.update(sql.as_bytes());
+    for bind in binds {
+        hasher.update(b"|"); // Separator
+        hasher.update(bind.as_bytes());
+    }
+    let hash = hasher.finalize();
+
+    // Extract table name for pattern invalidation
+    let table = extract_table_name(sql).unwrap_or("unknown");
+    format!("torii:query:{}:{:x}", table, hash)
+}
+
+/// Extract the primary table name from a SQL query.
+///
+/// Handles both bracketed `[table_name]` and plain table names.
+fn extract_table_name(sql: &str) -> Option<&str> {
+    let sql_upper = sql.to_uppercase();
+    let from_idx = sql_upper.find("FROM ")?;
+    let after_from = &sql[from_idx + 5..];
+    let trimmed = after_from.trim_start();
+
+    // Handle [bracketed] or plain table names
+    if trimmed.starts_with('[') {
+        let end = trimmed.find(']')?;
+        Some(&trimmed[1..end])
+    } else {
+        let end = trimmed
+            .find(|c: char| c.is_whitespace() || c == ',' || c == ')' || c == ';')
+            .unwrap_or(trimmed.len());
+        if end == 0 {
+            None
+        } else {
+            Some(&trimmed[..end])
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+
+    #[test]
+    fn test_extract_table_name_simple() {
+        let sql = "SELECT * FROM entities WHERE id = ?";
+        assert_eq!(extract_table_name(sql), Some("entities"));
+    }
+
+    #[test]
+    fn test_extract_table_name_bracketed() {
+        let sql = "SELECT * FROM [entity_model] WHERE id = ?";
+        assert_eq!(extract_table_name(sql), Some("entity_model"));
+    }
+
+    #[test]
+    fn test_extract_table_name_with_join() {
+        let sql = "SELECT e.* FROM entities e JOIN models m ON e.id = m.entity_id";
+        assert_eq!(extract_table_name(sql), Some("entities"));
+    }
+
+    #[test]
+    fn test_extract_table_name_lowercase() {
+        let sql = "select * from tokens where id = ?";
+        assert_eq!(extract_table_name(sql), Some("tokens"));
+    }
+
+    #[test]
+    fn test_generate_cache_key() {
+        let sql = "SELECT * FROM entities WHERE id = ?";
+        let binds = vec!["123".to_string()];
+        let key = generate_cache_key(sql, &binds);
+
+        assert!(key.starts_with("torii:query:entities:"));
+        assert!(key.len() > 30); // Includes hash
+    }
+
+    #[test]
+    fn test_generate_cache_key_different_binds() {
+        let sql = "SELECT * FROM entities WHERE id = ?";
+        let key1 = generate_cache_key(sql, &["123".to_string()]);
+        let key2 = generate_cache_key(sql, &["456".to_string()]);
+
+        assert_ne!(key1, key2);
+    }
+
+    #[test]
+    fn test_generate_cache_key_same_query() {
+        let sql = "SELECT * FROM entities WHERE id = ?";
+        let binds = vec!["123".to_string()];
+        let key1 = generate_cache_key(sql, &binds);
+        let key2 = generate_cache_key(sql, &binds);
+
+        assert_eq!(key1, key2);
+    }
+}
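Note: the `torii:query:{table}:{hash}` layout is the contract that makes write-side invalidation work: the table segment extracted here must equal the string the executor later interpolates into `torii:query:{table}:*`. A small check of that contract:

```rust
use torii_cache::query_cache::generate_cache_key;

fn main() {
    let key = generate_cache_key(
        "SELECT * FROM token_balances WHERE account_address = ?",
        &["0x123".to_string()],
    );

    // The pattern the executor emits after an ApplyBalanceDiff write.
    let pattern = "torii:query:token_balances:*";
    let prefix = pattern.trim_end_matches('*');

    // Prefix matching is exactly what InMemoryQueryCache::invalidate_pattern does.
    assert!(key.starts_with(prefix));
}
```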
= "redis")] +pub mod tiered; + +pub use cached_row::{CachedColumn, CachedPage, CachedRow, CachedValue}; +pub use in_memory::InMemoryQueryCache; +pub use key::generate_cache_key; + +/// Configuration for query caching. +#[derive(Debug, Clone)] +pub struct QueryCacheConfig { + /// Whether caching is enabled. + pub enabled: bool, + /// Time-to-live for cache entries in seconds. + pub ttl_seconds: u64, +} + +impl Default for QueryCacheConfig { + fn default() -> Self { + Self { enabled: false, ttl_seconds: 60 } + } +} + +/// Result type for cache lookups. +#[derive(Debug)] +pub enum CacheResult { + /// Cache hit - returns the cached data. + Hit(T), + /// Cache miss - data not found. + Miss, + /// Cache is disabled. + Disabled, +} + +/// Cache error type. +#[derive(Debug, thiserror::Error)] +pub enum QueryCacheError { + #[error("Serialization error: {0}")] + Serialization(String), + #[error("Cache backend error: {0}")] + Backend(String), + #[cfg(feature = "redis")] + #[error("Redis error: {0}")] + Redis(#[from] ::redis::RedisError), +} + +/// Trait defining the query cache interface. +/// +/// Implementors must be thread-safe (Send + Sync) and debuggable. +#[async_trait] +pub trait QueryCache: Send + Sync + std::fmt::Debug { + /// Get a cached value by key. + async fn get(&self, key: &str) -> CacheResult>; + + /// Set a cached value. + async fn set(&self, key: &str, value: &[u8]) -> Result<(), QueryCacheError>; + + /// Invalidate a specific cache key. + async fn invalidate(&self, key: &str) -> Result<(), QueryCacheError>; + + /// Invalidate all cache keys matching a pattern (e.g., "torii:query:entities:*"). + async fn invalidate_pattern(&self, pattern: &str) -> Result<(), QueryCacheError>; + + /// Check if the cache backend is currently available. + fn is_available(&self) -> bool; + + /// Get the cache configuration. + fn config(&self) -> &QueryCacheConfig; +} diff --git a/crates/cache/src/query_cache/redis.rs b/crates/cache/src/query_cache/redis.rs new file mode 100644 index 00000000..9e84f655 --- /dev/null +++ b/crates/cache/src/query_cache/redis.rs @@ -0,0 +1,162 @@ +use async_trait::async_trait; +use redis::{aio::MultiplexedConnection, AsyncCommands, Client}; +use std::sync::atomic::{AtomicBool, Ordering}; +use tokio::sync::RwLock; + +use super::{CacheResult, QueryCache, QueryCacheConfig, QueryCacheError}; + +/// Redis-backed query cache implementation. +#[derive(Debug)] +pub struct RedisQueryCache { + config: QueryCacheConfig, + client: Client, + connection: RwLock>, + available: AtomicBool, +} + +impl RedisQueryCache { + /// Create a new Redis cache and verify connectivity. + pub async fn new(url: &str, config: QueryCacheConfig) -> Result { + let client = Client::open(url)?; + + // Test connection + let mut conn = client.get_multiplexed_async_connection().await?; + let _: () = redis::cmd("PING").query_async(&mut conn).await?; + + Ok(Self { + config, + client, + connection: RwLock::new(Some(conn)), + available: AtomicBool::new(true), + }) + } + + /// Get or create a connection, handling reconnection on failure. 
diff --git a/crates/cache/src/query_cache/redis.rs b/crates/cache/src/query_cache/redis.rs
new file mode 100644
index 00000000..9e84f655
--- /dev/null
+++ b/crates/cache/src/query_cache/redis.rs
@@ -0,0 +1,162 @@
+use async_trait::async_trait;
+use redis::{aio::MultiplexedConnection, AsyncCommands, Client};
+use std::sync::atomic::{AtomicBool, Ordering};
+use tokio::sync::RwLock;
+
+use super::{CacheResult, QueryCache, QueryCacheConfig, QueryCacheError};
+
+/// Redis-backed query cache implementation.
+#[derive(Debug)]
+pub struct RedisQueryCache {
+    config: QueryCacheConfig,
+    client: Client,
+    connection: RwLock<Option<MultiplexedConnection>>,
+    available: AtomicBool,
+}
+
+impl RedisQueryCache {
+    /// Create a new Redis cache and verify connectivity.
+    pub async fn new(url: &str, config: QueryCacheConfig) -> Result<Self, QueryCacheError> {
+        let client = Client::open(url)?;
+
+        // Test connection
+        let mut conn = client.get_multiplexed_async_connection().await?;
+        let _: () = redis::cmd("PING").query_async(&mut conn).await?;
+
+        Ok(Self {
+            config,
+            client,
+            connection: RwLock::new(Some(conn)),
+            available: AtomicBool::new(true),
+        })
+    }
+
+    /// Get or create a connection, handling reconnection on failure.
+    async fn get_connection(&self) -> Option<MultiplexedConnection> {
+        // Try to reuse existing connection
+        {
+            let conn = self.connection.read().await;
+            if conn.is_some() {
+                return conn.clone();
+            }
+        }
+
+        // Try to reconnect
+        match self.client.get_multiplexed_async_connection().await {
+            Ok(conn) => {
+                let mut lock = self.connection.write().await;
+                *lock = Some(conn.clone());
+                self.available.store(true, Ordering::Relaxed);
+                Some(conn)
+            }
+            Err(_) => {
+                self.available.store(false, Ordering::Relaxed);
+                None
+            }
+        }
+    }
+
+    /// Mark connection as failed and clear it.
+    async fn mark_connection_failed(&self) {
+        let mut lock = self.connection.write().await;
+        *lock = None;
+        self.available.store(false, Ordering::Relaxed);
+    }
+}
+
+#[async_trait]
+impl QueryCache for RedisQueryCache {
+    async fn get(&self, key: &str) -> CacheResult<Vec<u8>> {
+        if !self.config.enabled {
+            return CacheResult::Disabled;
+        }
+
+        let mut conn = match self.get_connection().await {
+            Some(c) => c,
+            None => return CacheResult::Miss,
+        };
+
+        match conn.get::<_, Option<Vec<u8>>>(key).await {
+            Ok(Some(data)) => {
+                self.available.store(true, Ordering::Relaxed);
+                CacheResult::Hit(data)
+            }
+            Ok(None) => CacheResult::Miss,
+            Err(_) => {
+                self.mark_connection_failed().await;
+                CacheResult::Miss
+            }
+        }
+    }
+
+    async fn set(&self, key: &str, value: &[u8]) -> Result<(), QueryCacheError> {
+        if !self.config.enabled {
+            return Ok(());
+        }
+
+        let mut conn = match self.get_connection().await {
+            Some(c) => c,
+            None => return Err(QueryCacheError::Backend("Redis unavailable".to_string())),
+        };
+
+        match conn.set_ex::<_, _, ()>(key, value, self.config.ttl_seconds).await {
+            Ok(_) => {
+                self.available.store(true, Ordering::Relaxed);
+                Ok(())
+            }
+            Err(e) => {
+                self.mark_connection_failed().await;
+                Err(QueryCacheError::Redis(e))
+            }
+        }
+    }
+
+    async fn invalidate(&self, key: &str) -> Result<(), QueryCacheError> {
+        let mut conn = match self.get_connection().await {
+            Some(c) => c,
+            None => return Err(QueryCacheError::Backend("Redis unavailable".to_string())),
+        };
+
+        match conn.del::<_, ()>(key).await {
+            Ok(_) => Ok(()),
+            Err(e) => {
+                self.mark_connection_failed().await;
+                Err(QueryCacheError::Redis(e))
+            }
+        }
+    }
+
+    async fn invalidate_pattern(&self, pattern: &str) -> Result<(), QueryCacheError> {
+        let mut conn = match self.get_connection().await {
+            Some(c) => c,
+            None => return Err(QueryCacheError::Backend("Redis unavailable".to_string())),
+        };
+
+        // NOTE: KEYS is O(N) over the keyspace and can block the server; it is
+        // used here for simplicity. Production deployments should prefer SCAN
+        // with iteration (see the sketch below).
+        let keys: Vec<String> = match redis::cmd("KEYS").arg(pattern).query_async(&mut conn).await {
+            Ok(k) => k,
+            Err(e) => {
+                self.mark_connection_failed().await;
+                return Err(QueryCacheError::Redis(e));
+            }
+        };
+
+        if !keys.is_empty() {
+            if let Err(e) = conn.del::<_, ()>(keys).await {
+                self.mark_connection_failed().await;
+                return Err(QueryCacheError::Redis(e));
+            }
+        }
+
+        Ok(())
+    }
+
+    fn is_available(&self) -> bool {
+        self.available.load(Ordering::Relaxed)
+    }
+
+    fn config(&self) -> &QueryCacheConfig {
+        &self.config
+    }
+}
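Note: as the comment in `invalidate_pattern` concedes, `KEYS` blocks the Redis server while it walks the whole keyspace. A `SCAN`-based variant sketched against the redis 0.25 async API (untested here; the collect-then-delete split is needed because the scan iterator borrows the connection):

```rust
use redis::aio::MultiplexedConnection;
use redis::AsyncCommands;

async fn invalidate_pattern_scan(
    conn: &mut MultiplexedConnection,
    pattern: &str,
) -> Result<(), redis::RedisError> {
    // Phase 1: walk the keyspace incrementally instead of calling KEYS.
    let keys: Vec<String> = {
        let mut iter = conn.scan_match::<_, String>(pattern).await?;
        let mut keys = Vec::new();
        while let Some(key) = iter.next_item().await {
            keys.push(key);
        }
        keys
    };

    // Phase 2: delete after the iterator releases its borrow of the connection.
    if !keys.is_empty() {
        conn.del::<_, ()>(keys).await?;
    }
    Ok(())
}
```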
diff --git a/crates/cache/src/query_cache/tiered.rs b/crates/cache/src/query_cache/tiered.rs
new file mode 100644
index 00000000..1e2dac2e
--- /dev/null
+++ b/crates/cache/src/query_cache/tiered.rs
@@ -0,0 +1,105 @@
+use async_trait::async_trait;
+use std::sync::Arc;
+
+use super::{CacheResult, QueryCache, QueryCacheConfig, QueryCacheError};
+
+/// A tiered cache that tries a primary cache (Redis) first, then falls back to
+/// a secondary cache (in-memory).
+#[derive(Debug)]
+pub struct TieredQueryCache {
+    /// Primary cache (typically Redis).
+    primary: Arc<dyn QueryCache>,
+    /// Fallback cache (typically in-memory).
+    fallback: Arc<dyn QueryCache>,
+}
+
+impl TieredQueryCache {
+    /// Create a new tiered cache with primary and fallback backends.
+    pub fn new(primary: Arc<dyn QueryCache>, fallback: Arc<dyn QueryCache>) -> Self {
+        Self { primary, fallback }
+    }
+}
+
+#[async_trait]
+impl QueryCache for TieredQueryCache {
+    async fn get(&self, key: &str) -> CacheResult<Vec<u8>> {
+        // Try primary (Redis) first if available
+        if self.primary.is_available() {
+            if let CacheResult::Hit(data) = self.primary.get(key).await {
+                return CacheResult::Hit(data);
+            }
+        }
+        // Fallback to secondary (in-memory)
+        self.fallback.get(key).await
+    }
+
+    async fn set(&self, key: &str, value: &[u8]) -> Result<(), QueryCacheError> {
+        // Write to both caches
+        // Primary write failure is non-fatal if fallback succeeds
+        if self.primary.is_available() {
+            let _ = self.primary.set(key, value).await;
+        }
+        self.fallback.set(key, value).await
+    }
+
+    async fn invalidate(&self, key: &str) -> Result<(), QueryCacheError> {
+        // Invalidate in both caches
+        if self.primary.is_available() {
+            let _ = self.primary.invalidate(key).await;
+        }
+        self.fallback.invalidate(key).await
+    }
+
+    async fn invalidate_pattern(&self, pattern: &str) -> Result<(), QueryCacheError> {
+        // Invalidate pattern in both caches
+        if self.primary.is_available() {
+            let _ = self.primary.invalidate_pattern(pattern).await;
+        }
+        self.fallback.invalidate_pattern(pattern).await
+    }
+
+    fn is_available(&self) -> bool {
+        // Available if either cache is available
+        self.primary.is_available() || self.fallback.is_available()
+    }
+
+    fn config(&self) -> &QueryCacheConfig {
+        // Return fallback config (always available)
+        self.fallback.config()
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::super::in_memory::InMemoryQueryCache;
+    use super::*;
+
+    #[tokio::test]
+    async fn test_tiered_fallback_when_primary_unavailable() {
+        let config = QueryCacheConfig { enabled: true, ttl_seconds: 60 };
+
+        // Create two in-memory caches to simulate tiered behavior
+        // (In real usage, primary would be Redis)
+        let primary = Arc::new(InMemoryQueryCache::new(config.clone()));
+        let fallback = Arc::new(InMemoryQueryCache::new(config.clone()));
+
+        let tiered = TieredQueryCache::new(primary.clone(), fallback.clone());
+
+        // Set via tiered (writes to both)
+        tiered.set("key", b"value").await.unwrap();
+
+        // Both should have the value
+        let result = primary.get("key").await;
+        assert!(matches!(result, CacheResult::Hit(_)));
+
+        let result = fallback.get("key").await;
+        assert!(matches!(result, CacheResult::Hit(_)));
+
+        // Get via tiered should hit primary first
+        let result = tiered.get("key").await;
+        match result {
+            CacheResult::Hit(data) => assert_eq!(data, b"value"),
+            _ => panic!("Expected hit"),
+        }
+    }
+}
diff --git a/crates/cli/src/args.rs b/crates/cli/src/args.rs
index bb9a3700..54a3894f 100644
--- a/crates/cli/src/args.rs
+++ b/crates/cli/src/args.rs
@@ -62,6 +62,10 @@ pub struct ToriiArgs {
     #[merge]
     pub sql: SqlOptions,
 
+    #[command(flatten)]
+    #[merge]
+    pub query_cache: QueryCacheOptions,
+
     #[command(flatten)]
     #[merge]
     pub activity: ActivityOptions,
@@ -117,6 +121,7 @@ impl Default for ToriiArgs {
             events: EventsOptions::default(),
             erc: ErcOptions::default(),
             sql: SqlOptions::default(),
+            query_cache: QueryCacheOptions::default(),
             activity: ActivityOptions::default(),
             achievement: AchievementOptions::default(),
             snapshot: SnapshotOptions::default(),
diff --git a/crates/cli/src/options.rs b/crates/cli/src/options.rs
index e2fac63a..d1f61c93 100644
--- a/crates/cli/src/options.rs
+++ b/crates/cli/src/options.rs
@@ -45,6 +45,10 @@ pub const DEFAULT_DATABASE_MAX_CONNECTIONS: u32 = 100;
 pub const DEFAULT_MESSAGING_MAX_AGE: u64 = 300_000;
 pub const DEFAULT_MESSAGING_FUTURE_TOLERANCE: u64 = 60_000;
 
+// Query cache defaults
+/// Default TTL for query cache entries in seconds (60 seconds)
+pub const DEFAULT_QUERY_CACHE_TTL: u64 = 60;
+
 // Activity tracking defaults
 /// Default session timeout in seconds (1 hour)
 pub const DEFAULT_ACTIVITY_SESSION_TIMEOUT: u64 = 3600;
@@ -401,6 +405,39 @@ impl Default for MetricsOptions {
     }
 }
 
+#[derive(Debug, clap::Args, Clone, Serialize, Deserialize, PartialEq, MergeOptions)]
+#[serde(default)]
+#[command(next_help_heading = "Query cache options")]
+pub struct QueryCacheOptions {
+    /// Enable query result caching for improved read performance.
+    ///
+    /// When enabled, SELECT query results are cached in memory (and optionally Redis).
+    /// This can significantly improve read performance for frequently accessed data.
+    #[arg(long = "cache.enabled", default_value_t = false)]
+    pub cache_enabled: bool,
+
+    /// Cache TTL (time-to-live) in seconds.
+    ///
+    /// Cached query results are automatically invalidated after this duration.
+    /// Additionally, caches are invalidated immediately when data is modified.
+    #[arg(long = "cache.ttl", default_value_t = DEFAULT_QUERY_CACHE_TTL)]
+    pub cache_ttl: u64,
+
+    /// Redis URL for distributed caching.
+    ///
+    /// When provided, query results are cached in Redis in addition to the in-memory cache.
+    /// This enables cache sharing across multiple Torii instances.
+    /// Format: redis://[username:password@]host[:port][/database]
+    #[arg(long = "cache.redis_url", value_name = "URL")]
+    pub redis_url: Option<String>,
+}
+
+impl Default for QueryCacheOptions {
+    fn default() -> Self {
+        Self { cache_enabled: false, cache_ttl: DEFAULT_QUERY_CACHE_TTL, redis_url: None }
+    }
+}
+
 #[derive(Debug, clap::Args, Clone, Serialize, Deserialize, PartialEq, MergeOptions)]
 #[serde(default)]
 #[command(next_help_heading = "ERC options")]
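Note: these flags map one-to-one onto `QueryCacheConfig`; the runner change below performs the mapping inline. For reference, the mapping amounts to the following (a sketch; the free function and the `torii_cli::options` import path are assumptions, not part of the diff):

```rust
use torii_cache::query_cache::QueryCacheConfig;
use torii_cli::options::QueryCacheOptions;

fn to_cache_config(opts: &QueryCacheOptions) -> QueryCacheConfig {
    QueryCacheConfig {
        enabled: opts.cache_enabled,
        ttl_seconds: opts.cache_ttl,
    }
}
```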

diff --git a/crates/grpc/server/src/lib.rs b/crates/grpc/server/src/lib.rs
index 6ffaef25..e201ed05 100644
--- a/crates/grpc/server/src/lib.rs
+++ b/crates/grpc/server/src/lib.rs
@@ -45,7 +45,7 @@ use crate::subscriptions::transaction::TransactionManager;
 use self::subscriptions::entity::EntityManager;
 use self::subscriptions::event_message::EventMessageManager;
-use sqlx::SqlitePool;
+use torii_sqlite::caching_pool::{CachedQueryResult, CachingPool};
 use torii_proto::proto::world::world_server::WorldServer;
 use torii_proto::proto::world::{
     PublishMessageBatchRequest, PublishMessageBatchResponse, PublishMessageRequest,
@@ -94,7 +94,7 @@ pub struct DojoWorld {
     aggregation_manager: Arc<AggregationManager>,
     activity_manager: Arc<ActivityManager>,
     achievement_progression_manager: Arc<AchievementProgressionManager>,
-    pool: SqlitePool,
+    pool: CachingPool,
     _config: GrpcConfig,
 }
@@ -103,7 +103,7 @@ impl DojoWorld {
     storage: Arc<dyn ReadOnlyStorage>,
     messaging: Arc<Messaging<P>>,
     cross_messaging_tx: Option<UnboundedSender<Message>>,
-    pool: SqlitePool,
+    pool: CachingPool,
     config: GrpcConfig,
 ) -> Self {
     let entity_manager = Arc::new(EntityManager::new(config.clone()));
@@ -1185,17 +1185,22 @@ impl proto::world::world_server::World for
     ) -> Result<Response<proto::types::SqlQueryResponse>, Status> {
         let proto::types::SqlQueryRequest { query } = request.into_inner();
 
-        // Execute the query
-        let rows = sqlx::query(&query)
-            .fetch_all(&self.pool)
+        let result = self
+            .pool
+            .fetch_all_cached(&query, &[])
             .await
             .map_err(|e| Status::invalid_argument(format!("Query error: {:?}", e)))?;
 
-        // Map rows to proto types
-        let proto_rows: Vec<proto::types::SqlRow> = rows
-            .iter()
-            .map(torii_sqlite::utils::map_row_to_proto)
-            .collect();
+        let proto_rows: Vec<proto::types::SqlRow> = match result {
+            CachedQueryResult::Cached(rows) => rows
+                .iter()
+                .map(torii_sqlite::utils::map_cached_row_to_proto)
+                .collect(),
+            CachedQueryResult::Fresh(rows) => rows
+                .iter()
+                .map(torii_sqlite::utils::map_row_to_proto)
+                .collect(),
+        };
 
         Ok(Response::new(proto::types::SqlQueryResponse {
             rows: proto_rows,
@@ -1248,7 +1253,7 @@ pub async fn new(
     storage: Arc<dyn ReadOnlyStorage>,
     messaging: Arc<Messaging<P>>,
     cross_messaging_tx: UnboundedSender<Message>,
-    pool: SqlitePool,
+    pool: CachingPool,
     config: GrpcConfig,
     bind_addr: Option<SocketAddr>,
 ) -> Result<
diff --git a/crates/grpc/server/src/tests/entities_test.rs b/crates/grpc/server/src/tests/entities_test.rs
index 7a9ba5ad..c53d5ebe 100644
--- a/crates/grpc/server/src/tests/entities_test.rs
+++ b/crates/grpc/server/src/tests/entities_test.rs
@@ -32,6 +32,7 @@ use torii_proto::proto::world::RetrieveEntitiesRequest;
 use torii_proto::{Clause, KeysClause, PatternMatching, Query};
 use torii_proto::schema::Entity;
+use torii_sqlite::caching_pool::CachingPool;
 use torii_sqlite::executor::Executor;
 use torii_sqlite::{Sql, SqlConfig};
 use torii_storage::proto::{ContractDefinition, ContractType};
@@ -169,7 +170,7 @@ async fn test_entities_queries(sequencer: &RunnerCtx) {
         storage,
         messaging.clone(),
         None,
-        pool.clone(),
+        CachingPool::new(pool.clone()),
         GrpcConfig::default(),
     );
@@ -329,7 +330,7 @@ async fn test_keys_clause_with_empty_models(sequencer: &RunnerCtx) {
         storage,
         messaging.clone(),
         None,
-        pool.clone(),
+        CachingPool::new(pool.clone()),
         GrpcConfig::default(),
     );
@@ -490,7 +491,7 @@ async fn test_keys_clause_with_specific_models(sequencer: &RunnerCtx) {
         storage,
         messaging.clone(),
         None,
-        pool.clone(),
+        CachingPool::new(pool.clone()),
         GrpcConfig::default(),
     );
@@ -651,7 +652,7 @@ async fn test_hashed_keys_clause(sequencer: &RunnerCtx) {
         storage,
         messaging.clone(),
         None,
-        pool.clone(),
+        CachingPool::new(pool.clone()),
         GrpcConfig::default(),
     );
@@ -808,7 +809,7 @@ async fn test_member_clause(sequencer: &RunnerCtx) {
         storage,
         messaging.clone(),
         None,
-        pool.clone(),
+        CachingPool::new(pool.clone()),
         GrpcConfig::default(),
     );
@@ -971,7 +972,7 @@ async fn test_composite_clause_and(sequencer: &RunnerCtx) {
         storage,
         messaging.clone(),
         None,
-        pool.clone(),
+        CachingPool::new(pool.clone()),
         GrpcConfig::default(),
     );
@@ -1164,7 +1165,7 @@ async fn test_composite_clause_or(sequencer: &RunnerCtx) {
         storage,
         messaging.clone(),
         None,
-        pool.clone(),
+        CachingPool::new(pool.clone()),
         GrpcConfig::default(),
     );
@@ -1347,7 +1348,7 @@ async fn test_historical_query(sequencer: &RunnerCtx) {
         storage,
         messaging.clone(),
         None,
-        pool.clone(),
+        CachingPool::new(pool.clone()),
         GrpcConfig::default(),
     );
diff --git a/crates/grpc/server/src/tests/messaging.rs b/crates/grpc/server/src/tests/messaging.rs
index e1ab9d09..d1e8fe85 100644
--- a/crates/grpc/server/src/tests/messaging.rs
+++ b/crates/grpc/server/src/tests/messaging.rs
@@ -20,6 +20,7 @@ use tonic::Request;
 use torii_libp2p_relay::Relay;
 use torii_messaging::{Messaging, MessagingConfig};
 use torii_proto::proto::world::PublishMessageRequest;
+use torii_sqlite::caching_pool::CachingPool;
 use torii_sqlite::executor::Executor;
 use torii_sqlite::Sql;
 use torii_storage::proto::{ContractDefinition, ContractType};
@@ -115,7 +116,7 @@ async fn test_publish_message(sequencer: &RunnerCtx) {
         db.clone(),
         messaging,
         None,
-        pool.clone(),
+        CachingPool::new(pool.clone()),
         GrpcConfig::default(),
     );
@@ -437,7 +438,7 @@ async fn test_cross_messaging_between_relay_servers(sequencer: &RunnerCtx) {
         db1,
         messaging1,
         Some(cross_messaging_tx1),
-        pool1.clone(),
+        CachingPool::new(pool1.clone()),
         GrpcConfig::default(),
     );
@@ -636,7 +637,7 @@ async fn test_publish_message_with_bad_signature_fails(sequencer: &RunnerCtx) {
         db.clone(),
         messaging,
         None,
-        pool.clone(),
+        CachingPool::new(pool.clone()),
         GrpcConfig::default(),
     );
@@ -831,7 +832,7 @@ async fn test_timestamp_validation_logic(sequencer: &RunnerCtx) {
         db.clone(),
         messaging,
         None,
-        pool.clone(),
+        CachingPool::new(pool.clone()),
         GrpcConfig::default(),
     );
diff --git a/crates/runner/Cargo.toml b/crates/runner/Cargo.toml
index 806e92a6..957c7d10 100644
--- a/crates/runner/Cargo.toml
+++ b/crates/runner/Cargo.toml
@@ -72,3 +72,4 @@ camino.workspace = true
 default = ["jemalloc", "sqlite"]
 jemalloc = ["dojo-metrics/jemalloc"]
 sqlite = ["sqlx/sqlite"]
+redis = ["torii-cache/redis"]
diff --git a/crates/runner/src/lib.rs b/crates/runner/src/lib.rs
index 7f2ceb64..73409a5f 100644
--- a/crates/runner/src/lib.rs
+++ b/crates/runner/src/lib.rs
@@ -41,6 +41,9 @@ use tokio::sync::broadcast::Sender;
 use tokio_stream::StreamExt;
 use torii_broker::types::ModelUpdate;
 use torii_broker::MemoryBroker;
+use torii_cache::query_cache::{InMemoryQueryCache, QueryCache, QueryCacheConfig};
+#[cfg(feature = "redis")]
+use torii_cache::query_cache::{RedisQueryCache, TieredQueryCache};
 use torii_cache::InMemoryCache;
 use torii_cli::ToriiArgs;
 use torii_controllers::sync::ControllersSync;
@@ -51,6 +54,7 @@ use torii_libp2p_relay::Relay;
 use torii_messaging::{Messaging, MessagingConfig};
 use torii_processors::{EventProcessorConfig, Processors};
 use torii_server::proxy::{Proxy, ProxySettings};
+use torii_sqlite::caching_pool::CachingPool;
 use torii_sqlite::executor::Executor;
 use torii_sqlite::{Sql, SqlConfig};
 use torii_storage::proto::{ContractDefinition, ContractType};
@@ -570,12 +574,56 @@ impl Runner {
             search_snippet_length: self.args.search.snippet_length,
         };
 
+        let query_cache: Option<Arc<dyn QueryCache>> = if self.args.query_cache.cache_enabled {
+            let query_cache_config = QueryCacheConfig {
+                enabled: true,
+                ttl_seconds: self.args.query_cache.cache_ttl,
+            };
+
+            #[cfg(feature = "redis")]
+            {
+                if let Some(url) = &self.args.query_cache.redis_url {
+                    match RedisQueryCache::new(url, query_cache_config.clone()).await {
+                        Ok(redis_cache) => {
+                            info!(target: LOG_TARGET, redis_url = %url, ttl = self.args.query_cache.cache_ttl, "Query cache enabled with Redis");
+                            Some(Arc::new(TieredQueryCache::new(
+                                Arc::new(redis_cache),
+                                Arc::new(InMemoryQueryCache::new(query_cache_config)),
+                            )))
+                        }
+                        Err(error) => {
+                            warn!(target: LOG_TARGET, error = %error, "Failed to connect to Redis, falling back to in-memory query cache");
+                            Some(Arc::new(InMemoryQueryCache::new(query_cache_config)))
+                        }
+                    }
+                } else {
+                    info!(target: LOG_TARGET, ttl = self.args.query_cache.cache_ttl, "Query cache enabled with in-memory backend");
+                    Some(Arc::new(InMemoryQueryCache::new(query_cache_config)))
+                }
+            }
+
+            #[cfg(not(feature = "redis"))]
+            {
+                if self.args.query_cache.redis_url.is_some() {
+                    warn!(
+                        target: LOG_TARGET,
+                        "Redis URL provided but torii was built without the redis feature; using in-memory query cache"
+                    );
+                }
+                info!(target: LOG_TARGET, ttl = self.args.query_cache.cache_ttl, "Query cache enabled with in-memory backend");
+                Some(Arc::new(InMemoryQueryCache::new(query_cache_config)))
+            }
+        } else {
+            None
+        };
+
         let (mut executor, sender) = Executor::new_with_config(
             write_pool.clone(),
             shutdown_tx.clone(),
             provider.clone(),
             sql_config.clone(),
             database_path.clone(),
+            query_cache.clone(),
         )
         .await?;
         let executor_handle = tokio::spawn(async move { executor.run().await });
@@ -590,6 +638,12 @@ impl Runner {
         let cache = Arc::new(InMemoryCache::new(Arc::new(db.clone())).await.unwrap());
         let db = db.with_cache(cache.clone());
 
+        let db = if let Some(cache) = query_cache.clone() {
+            db.with_query_cache(cache)
+        } else {
+            db
+        };
+
         let processors = Arc::new(Processors::default());
 
         let mut indexing_flags = IndexingFlags::empty();
@@ -711,6 +765,12 @@ impl Runner {
             provider.clone(),
         ));
 
+        let caching_pool = if let Some(cache) = query_cache.clone() {
+            CachingPool::new(readonly_pool.clone()).with_cache(cache)
+        } else {
+            CachingPool::new(readonly_pool.clone())
+        };
+
         let (mut libp2p_relay_server, cross_messaging_tx) = Relay::new_with_peers(
             messaging.clone(),
             self.args.relay.port,
@@ -728,7 +788,7 @@ impl Runner {
             storage.clone(),
             messaging.clone(),
             cross_messaging_tx,
-            readonly_pool.clone(),
+            caching_pool.clone(),
             GrpcConfig {
                 subscription_buffer_size: self.args.grpc.subscription_buffer_size,
                 optimistic: self.args.grpc.optimistic,
diff --git a/crates/sqlite/sqlite/Cargo.toml b/crates/sqlite/sqlite/Cargo.toml
index 6ad619ae..2fc3858b 100644
--- a/crates/sqlite/sqlite/Cargo.toml
+++ b/crates/sqlite/sqlite/Cargo.toml
@@ -13,6 +13,7 @@ torii-broker.workspace = true
 torii-sqlite-types.workspace = true
 
 anyhow.workspace = true
+bincode.workspace = true
 async-trait.workspace = true
 base64.workspace = true
 bitflags = "2.6.0"
diff --git a/crates/sqlite/sqlite/src/caching_pool.rs b/crates/sqlite/sqlite/src/caching_pool.rs
new file mode 100644
index 00000000..311a5e6f
--- /dev/null
+++ b/crates/sqlite/sqlite/src/caching_pool.rs
@@ -0,0 +1,356 @@
+//! CachingPool - A wrapper around Pool<Sqlite> that provides caching for SQL query results.
+//!
+//! This module provides a transparent cache layer for SQLite queries that:
+//! - Caches SELECT query results using configurable backends (in-memory, Redis)
+//! - Generates cache keys from SQL + bind parameters
+//! - Supports pattern-based invalidation for write operations
+//! - Falls back to direct pool access when caching is disabled
+
+use sqlx::sqlite::SqliteRow;
+use sqlx::{Pool, Row, Sqlite};
+use std::sync::Arc;
+use torii_cache::query_cache::{
+    generate_cache_key, CacheResult, CachedRow, QueryCache, QueryCacheConfig,
+};
+use tracing::{debug, trace, warn};
+
+use crate::error::Error;
+
+/// Wrapper around Pool<Sqlite> that provides transparent query caching.
+#[derive(Debug, Clone)]
+pub struct CachingPool {
+    inner: Pool<Sqlite>,
+    cache: Option<Arc<dyn QueryCache>>,
+}
+
+impl CachingPool {
+    /// Create a new CachingPool without caching (pass-through mode).
+    pub fn new(pool: Pool<Sqlite>) -> Self {
+        Self { inner: pool, cache: None }
+    }
+
+    /// Add a cache backend to the pool.
+    pub fn with_cache(mut self, cache: Arc<dyn QueryCache>) -> Self {
+        self.cache = Some(cache);
+        self
+    }
+
+    /// Get the inner pool for direct access (writes, schema operations, etc.).
+    pub fn inner(&self) -> &Pool<Sqlite> {
+        &self.inner
+    }
+
+    /// Check if caching is enabled and available.
+    pub fn is_cache_enabled(&self) -> bool {
+        self.cache.as_ref().map(|c| c.config().enabled && c.is_available()).unwrap_or(false)
+    }
+
+    /// Get cache configuration, if available.
+    pub fn cache_config(&self) -> Option<&QueryCacheConfig> {
+        self.cache.as_ref().map(|c| c.config())
+    }
+
+    /// Execute a SELECT query with caching.
+    ///
+    /// This method:
+    /// 1. Generates a cache key from the SQL and bind parameters
+    /// 2. Checks the cache for a hit
+    /// 3. On miss, executes the query against the database
+    /// 4. Caches the result for future requests
+    ///
+    /// # Arguments
+    /// * `sql` - The SQL query string
+    /// * `binds` - The bind parameter values as strings
+    ///
+    /// # Returns
+    /// Cached rows that can be converted to the desired type.
+    pub async fn fetch_all_cached(
+        &self,
+        sql: &str,
+        binds: &[String],
+    ) -> Result<CachedQueryResult, Error> {
+        let cache_key = generate_cache_key(sql, binds);
+
+        // Try cache first
+        if let Some(cache) = &self.cache {
+            match cache.get(&cache_key).await {
+                CacheResult::Hit(data) => {
+                    trace!(target: "torii::sqlite::caching_pool", key = %cache_key, "Cache hit");
+                    if let Ok(cached_rows) = bincode::deserialize::<Vec<CachedRow>>(&data) {
+                        return Ok(CachedQueryResult::Cached(cached_rows));
+                    }
+                    // Deserialization failed, fall through to database
+                    warn!(
+                        target: "torii::sqlite::caching_pool",
+                        key = %cache_key,
+                        "Cache deserialization failed, executing query"
+                    );
+                }
+                CacheResult::Miss => {
+                    trace!(target: "torii::sqlite::caching_pool", key = %cache_key, "Cache miss");
+                }
+                CacheResult::Disabled => {
+                    // Caching disabled, just execute the query
+                }
+            }
+        }
+
+        // Execute query against database
+        let rows = self.execute_query(sql, binds).await?;
+
+        // Cache the result
+        if let Some(cache) = &self.cache {
+            let cached_rows: Vec<CachedRow> =
+                rows.iter().map(CachedRow::from_sqlite_row).collect();
+            if let Ok(data) = bincode::serialize(&cached_rows) {
+                if let Err(e) = cache.set(&cache_key, &data).await {
+                    debug!(
+                        target: "torii::sqlite::caching_pool",
+                        key = %cache_key,
+                        error = %e,
+                        "Failed to cache query result"
+                    );
+                }
+            }
+        }
+
+        Ok(CachedQueryResult::Fresh(rows))
+    }
+
+    /// Execute a query directly against the database.
+    async fn execute_query(&self, sql: &str, binds: &[String]) -> Result<Vec<SqliteRow>, Error> {
+        let mut query = sqlx::query(sql);
+        for bind in binds {
+            query = query.bind(bind);
+        }
+        Ok(query.fetch_all(&self.inner).await?)
+    }
+
+    /// Invalidate all cached queries for a specific table.
+    ///
+    /// This should be called after write operations that modify a table.
+    pub async fn invalidate_table(&self, table: &str) {
+        if let Some(cache) = &self.cache {
+            let pattern = format!("torii:query:{}:*", table);
+            if let Err(e) = cache.invalidate_pattern(&pattern).await {
+                warn!(
+                    target: "torii::sqlite::caching_pool",
+                    table = %table,
+                    error = %e,
+                    "Failed to invalidate cache pattern"
+                );
+            }
+        }
+    }
+
+    /// Invalidate multiple tables at once.
+    pub async fn invalidate_tables(&self, tables: &[&str]) {
+        for table in tables {
+            self.invalidate_table(table).await;
+        }
+    }
+
+    /// Invalidate a specific cache key.
+    pub async fn invalidate_key(&self, key: &str) {
+        if let Some(cache) = &self.cache {
+            if let Err(e) = cache.invalidate(key).await {
+                debug!(
+                    target: "torii::sqlite::caching_pool",
+                    key = %key,
+                    error = %e,
+                    "Failed to invalidate cache key"
+                );
+            }
+        }
+    }
+}
+
+/// Result of a cached query execution.
+pub enum CachedQueryResult {
+    /// Result came from cache.
+    Cached(Vec<CachedRow>),
+    /// Result came from database (fresh).
+    Fresh(Vec<SqliteRow>),
+}
+
+impl std::fmt::Debug for CachedQueryResult {
+    fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
+        match self {
+            Self::Cached(rows) => f.debug_tuple("Cached").field(&rows.len()).finish(),
+            Self::Fresh(rows) => f.debug_tuple("Fresh").field(&rows.len()).finish(),
+        }
+    }
+}
+
+impl CachedQueryResult {
+    /// Get the number of rows.
+    pub fn len(&self) -> usize {
+        match self {
+            Self::Cached(rows) => rows.len(),
+            Self::Fresh(rows) => rows.len(),
+        }
+    }
+
+    /// Check if result is empty.
+    pub fn is_empty(&self) -> bool {
+        self.len() == 0
+    }
+
+    /// Check if result came from cache.
+    pub fn is_cached(&self) -> bool {
+        matches!(self, Self::Cached(_))
+    }
+
+    /// Convert to CachedRows (either from cache or by converting fresh rows).
+    pub fn into_cached_rows(self) -> Vec<CachedRow> {
+        match self {
+            Self::Cached(rows) => rows,
+            Self::Fresh(rows) => rows.iter().map(CachedRow::from_sqlite_row).collect(),
+        }
+    }
+
+    /// Get a value by column index from the first row.
+    pub fn get_first_i64(&self, idx: usize) -> Option<i64> {
+        match self {
+            Self::Cached(rows) => {
+                rows.first().and_then(|row| row.values.get(idx)).and_then(|v| v.as_i64())
+            }
+            Self::Fresh(rows) => rows.first().and_then(|row| row.try_get::<i64, _>(idx).ok()),
+        }
+    }
+
+    /// Get a string value by column index from the first row.
+    pub fn get_first_string(&self, idx: usize) -> Option<String> {
+        match self {
+            Self::Cached(rows) => rows
+                .first()
+                .and_then(|row| row.values.get(idx))
+                .and_then(|v| v.as_str())
+                .map(|s| s.to_string()),
+            Self::Fresh(rows) => rows.first().and_then(|row| row.try_get::<String, _>(idx).ok()),
+        }
+    }
+}
+
+#[cfg(test)]
+mod tests {
+    use super::*;
+    use sqlx::sqlite::SqlitePoolOptions;
+    use torii_cache::query_cache::InMemoryQueryCache;
+
+    async fn create_test_pool() -> Pool<Sqlite> {
+        SqlitePoolOptions::new()
+            .max_connections(1)
+            .connect("sqlite::memory:")
+            .await
+            .unwrap()
+    }
+
+    #[tokio::test]
+    async fn test_caching_pool_without_cache() {
+        let pool = create_test_pool().await;
+
+        // Create table
+        sqlx::query("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")
+            .execute(&pool)
+            .await
+            .unwrap();
+
+        sqlx::query("INSERT INTO test (id, name) VALUES (1, 'test')")
+            .execute(&pool)
+            .await
+            .unwrap();
+
+        let caching_pool = CachingPool::new(pool);
+
+        // Should work without cache
+        let result = caching_pool
+            .fetch_all_cached("SELECT * FROM test WHERE id = ?", &["1".to_string()])
+            .await
+            .unwrap();
+
+        assert_eq!(result.len(), 1);
+        assert!(!result.is_cached());
+    }
+
+    #[tokio::test]
+    async fn test_caching_pool_with_cache() {
+        let pool = create_test_pool().await;
+
+        // Create table
+        sqlx::query("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")
+            .execute(&pool)
+            .await
+            .unwrap();
+
+        sqlx::query("INSERT INTO test (id, name) VALUES (1, 'test')")
+            .execute(&pool)
+            .await
+            .unwrap();
+
+        let config = QueryCacheConfig { enabled: true, ttl_seconds: 60 };
+        let cache = Arc::new(InMemoryQueryCache::new(config));
+        let caching_pool = CachingPool::new(pool).with_cache(cache);
+
+        // First query - should miss cache
+        let result1 = caching_pool
+            .fetch_all_cached("SELECT * FROM test WHERE id = ?", &["1".to_string()])
+            .await
+            .unwrap();
+
+        assert_eq!(result1.len(), 1);
+        assert!(!result1.is_cached());
+
+        // Second query - should hit cache
+        let result2 = caching_pool
+            .fetch_all_cached("SELECT * FROM test WHERE id = ?", &["1".to_string()])
+            .await
+            .unwrap();
+
+        assert_eq!(result2.len(), 1);
+        assert!(result2.is_cached());
+    }
+
+    #[tokio::test]
+    async fn test_cache_invalidation() {
+        let pool = create_test_pool().await;
+
+        // Create table
+        sqlx::query("CREATE TABLE test (id INTEGER PRIMARY KEY, name TEXT)")
+            .execute(&pool)
+            .await
+            .unwrap();
+
+        sqlx::query("INSERT INTO test (id, name) VALUES (1, 'test')")
+            .execute(&pool)
+            .await
+            .unwrap();
+
+        let config = QueryCacheConfig { enabled: true, ttl_seconds: 60 };
+        let cache = Arc::new(InMemoryQueryCache::new(config));
+        let caching_pool = CachingPool::new(pool).with_cache(cache);
+
+        // First query - populate cache
+        let _ = caching_pool
+            .fetch_all_cached("SELECT * FROM test WHERE id = ?", &["1".to_string()])
+            .await
+            .unwrap();
+
+        // Second query - hits cache
+        let result = caching_pool
+            .fetch_all_cached("SELECT * FROM test WHERE id = ?", &["1".to_string()])
+            .await
+            .unwrap();
+        assert!(result.is_cached());
+
+        // Invalidate the table
+        caching_pool.invalidate_table("test").await;
+
+        // Third query - should miss cache
+        let result = caching_pool
+            .fetch_all_cached("SELECT * FROM test WHERE id = ?", &["1".to_string()])
+            .await
+            .unwrap();
+        assert!(!result.is_cached());
+    }
+}
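Note: `fetch_all_cached` plus the first-row helpers let scalar queries stay agnostic about whether the page came back `Cached` or `Fresh`. A sketch against the `test` table shape used in the tests above:

```rust
use torii_sqlite::caching_pool::CachingPool;
use torii_sqlite::error::Error;

/// Cache-aware COUNT(*): the caller never branches on Cached vs Fresh.
async fn count_test_rows(pool: &CachingPool) -> Result<Option<i64>, Error> {
    let result = pool.fetch_all_cached("SELECT COUNT(*) FROM test", &[]).await?;
    // Column 0 of the first row, whichever variant came back.
    Ok(result.get_first_i64(0))
}
```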
diff --git a/crates/sqlite/sqlite/src/cursor.rs b/crates/sqlite/sqlite/src/cursor.rs
index 4176fb7e..5d155c86 100644
--- a/crates/sqlite/sqlite/src/cursor.rs
+++ b/crates/sqlite/sqlite/src/cursor.rs
@@ -5,6 +5,7 @@ use flate2::write::DeflateEncoder;
 use flate2::Compression;
 use sqlx::sqlite::SqliteRow;
 use sqlx::Row;
+use torii_cache::query_cache::{CachedRow, CachedValue};
 use std::io::prelude::*;
 use torii_proto::{OrderDirection, Pagination, PaginationDirection};
@@ -115,6 +116,36 @@ pub fn build_cursor_values(pagination: &Pagination, row: &SqliteRow) -> Result<Vec<String>, Error> {
+pub fn build_cursor_values_cached(
+    pagination: &Pagination,
+    row: &CachedRow,
+) -> Result<Vec<String>, Error> {
+    let mut values = Vec::new();
+    for ob in &pagination.order_by {
+        let value = row.get(&ob.field).ok_or_else(|| {
+            Error::Query(QueryError::InvalidCursor(format!(
+                "Missing cursor column {}",
+                ob.field
+            )))
+        })?;
+
+        let string_value = match value {
+            CachedValue::Text(text) => text.clone(),
+            CachedValue::Integer(int_val) => int_val.to_string(),
+            CachedValue::Real(real_val) => real_val.to_string(),
+            CachedValue::Blob(_) | CachedValue::Null => {
+                return Err(Error::Query(QueryError::InvalidCursor(format!(
+                    "Unsupported cursor value for column {}",
+                    ob.field
+                ))))
+            }
+        };
+
+        values.push(string_value);
+    }
+    Ok(values)
+}
+
 /// Encodes cursor values into a single string using a safe delimiter
 pub fn encode_cursor_values(values: &[String]) -> Result<String, Error> {
     let joined_values = values.join(CURSOR_DELIMITER);
diff --git a/crates/sqlite/sqlite/src/executor/mod.rs b/crates/sqlite/sqlite/src/executor/mod.rs
index 24c0de8a..308545c1 100644
--- a/crates/sqlite/sqlite/src/executor/mod.rs
+++ b/crates/sqlite/sqlite/src/executor/mod.rs
@@ -1,6 +1,7 @@
 use std::collections::{HashMap, HashSet};
 use std::path::PathBuf;
 use std::str::FromStr;
+use std::sync::Arc;
 use std::time::{SystemTime, UNIX_EPOCH};
 
 use cainome::cairo_serde::{ByteArray, CairoSerde};
@@ -39,6 +40,7 @@ use crate::utils::{
 };
 use crate::SqlConfig;
 use torii_broker::MemoryBroker;
+use torii_cache::query_cache::QueryCache;
 
 pub mod achievement;
 pub mod activity;
@@ -192,6 +194,8 @@ pub struct Executor<'c, P: Provider + Sync + Send + Clone + 'static> {
     db_path: PathBuf,
     // Timestamp of last optimization
     last_optimization: Option<Instant>,
+    // Optional query cache for invalidation after writes
+    query_cache: Option<Arc<dyn QueryCache>>,
 }
 
 #[derive(Debug)]
@@ -274,6 +278,7 @@ impl Executor<'_, P> {
             provider,
             crate::SqlConfig::default(),
             PathBuf::from(""),
+            None,
         )
         .await
     }
@@ -284,6 +289,7 @@
         provider: P,
         config: SqlConfig,
         db_path: PathBuf,
+        query_cache: Option<Arc<dyn QueryCache>>,
     ) -> Result<(Self, UnboundedSender<QueryMessage>)> {
         let (tx, rx) = unbounded_channel();
         let transaction = pool.begin().await?;
@@ -301,6 +307,7 @@
                 config,
                 db_path,
                 last_optimization: None,
+                query_cache,
             },
             tx,
         ))
@@ -338,6 +345,7 @@
     async fn handle_query_message(&mut self, query_message: QueryMessage) -> QueryResult<()> {
         let start_time = Instant::now();
         let query_type_str = format!("{}", query_message.query_type);
+        let query_type_for_invalidation = query_message.query_type.clone();
 
         let tx = self.transaction.as_mut().unwrap();
@@ -1177,6 +1185,8 @@
             }
         }
 
+        self.invalidate_query_cache(&query_type_for_invalidation).await;
+
         // Record metrics
         let duration = start_time.elapsed();
         histogram!(
@@ -1194,6 +1204,29 @@
         Ok(())
     }
 
+    async fn invalidate_query_cache(&self, query_type: &QueryType) {
+        let cache = match &self.query_cache {
+            Some(cache) => cache,
+            None => return,
+        };
+
+        let tables: &[&str] = match query_type {
+            QueryType::SetEntity(_) | QueryType::DeleteEntity(_) => {
+                &["entities", "entities_historical"]
+            }
+            QueryType::ApplyBalanceDiff(_) => &["token_balances", "tokens"],
+            QueryType::StoreTokenTransfer => &["token_transfers"],
+            QueryType::RegisterNftToken(_) => &["tokens"],
+            QueryType::RegisterTokenContract(_) => &["token_contracts"],
+            _ => return,
+        };
+
+        for table in tables {
+            let pattern = format!("torii:query:{}:*", table);
+            let _ = cache.invalidate_pattern(&pattern).await;
+        }
+    }
+
     async fn execute(&mut self) -> Result<()> {
         if let Some(transaction) = self.transaction.take() {
             transaction.commit().await?;
diff --git a/crates/sqlite/sqlite/src/lib.rs b/crates/sqlite/sqlite/src/lib.rs
index 862605a8..8f6d5591 100644
--- a/crates/sqlite/sqlite/src/lib.rs
+++ b/crates/sqlite/sqlite/src/lib.rs
@@ -6,6 +6,7 @@ use dojo_types::schema::Ty;
 use sqlx::{Pool, Sqlite};
 use starknet::core::types::Felt;
 use tokio::sync::mpsc::UnboundedSender;
+use torii_cache::query_cache::QueryCache;
 use torii_cache::Cache;
 use torii_proto::ContractDefinition;
 use torii_storage::Storage;
@@ -15,7 +16,9 @@ use crate::executor::error::ExecutorQueryError;
 use crate::executor::{Argument, QueryMessage};
 use crate::utils::utc_dt_string_from_timestamp;
 use torii_sqlite_types::{AggregatorConfig, Hook, ModelIndices};
+use crate::caching_pool::CachingPool;
 
+pub mod caching_pool;
 pub mod constants;
 pub mod cursor;
 pub mod error;
@@ -84,6 +87,7 @@ pub struct Sql {
     pub executor: UnboundedSender<QueryMessage>,
     pub config: SqlConfig,
     pub cache: Option<Arc<dyn Cache>>,
+    pub query_cache: Option<Arc<dyn QueryCache>>,
 }
 
 impl Sql {
@@ -120,6 +124,7 @@ impl Sql {
             executor,
             config,
             cache: None,
+            query_cache: None,
         };
 
         db.execute().await?;
@@ -134,6 +139,19 @@ impl Sql {
         }
     }
 
+    pub fn with_query_cache(self, query_cache: Arc<dyn QueryCache>) -> Self {
+        Self { query_cache: Some(query_cache), ..self }
+    }
+
+    pub fn caching_pool(&self) -> CachingPool {
+        let pool = CachingPool::new(self.pool.clone());
+        if let Some(cache) = self.query_cache.clone() {
+            pool.with_cache(cache)
+        } else {
+            pool
+        }
+    }
+
     fn set_entity_model(
         &self,
         model_name: &str,
diff --git a/crates/sqlite/sqlite/src/query.rs b/crates/sqlite/sqlite/src/query.rs
index cb49bfc1..3f7effb7 100644
--- a/crates/sqlite/sqlite/src/query.rs
+++ b/crates/sqlite/sqlite/src/query.rs
@@ -1,10 +1,13 @@
 use sqlx::{sqlite::SqliteRow, Pool, Sqlite};
+use torii_cache::query_cache::CachedRow;
 use torii_proto::{OrderBy, OrderDirection, Page, Pagination, PaginationDirection};
 
 use crate::{
+    caching_pool::CachingPool,
     constants::SQL_DEFAULT_LIMIT,
     cursor::{
-        build_cursor_conditions, build_cursor_values, decode_cursor_values, encode_cursor_values,
+        build_cursor_conditions, build_cursor_values, build_cursor_values_cached,
+        decode_cursor_values, encode_cursor_values,
     },
     error::Error,
 };
@@ -215,4 +218,73 @@ impl PaginationExecutor {
             next_cursor,
         })
     }
+
+    pub async fn execute_paginated_query_cached(
+        &self,
+        caching_pool: &CachingPool,
+        mut query_builder: QueryBuilder,
+        pagination: &Pagination,
+        default_order_by: &OrderBy,
+    ) -> Result<Page<CachedRow>, Error> {
+        let mut pagination = pagination.clone();
+        pagination.order_by.push(default_order_by.clone());
+
+        let original_limit = pagination.limit.unwrap_or(SQL_DEFAULT_LIMIT as u32);
+        let fetch_limit = original_limit + 1;
+
+        let cursor_values: Option<Vec<String>> = pagination
+            .cursor
+            .as_ref()
+            .map(|cursor_str| decode_cursor_values(cursor_str))
+            .transpose()?;
+
+        let (cursor_conditions, cursor_binds) =
+            build_cursor_conditions(&pagination, cursor_values.as_deref())?;
+
+        for condition in cursor_conditions {
+            query_builder = query_builder.where_clause(&condition);
+        }
+
+        for bind in cursor_binds {
+            query_builder = query_builder.bind_value(bind);
+        }
+
+        for order_by in &pagination.order_by {
+            let field = format!("[{}]", order_by.field);
+            let direction = match (&order_by.direction, &pagination.direction) {
+                (OrderDirection::Asc, PaginationDirection::Forward) => OrderDirection::Asc,
+                (OrderDirection::Asc, PaginationDirection::Backward) => OrderDirection::Desc,
+                (OrderDirection::Desc, PaginationDirection::Forward) => OrderDirection::Desc,
+                (OrderDirection::Desc, PaginationDirection::Backward) => OrderDirection::Asc,
+            };
+            query_builder = query_builder.order_by(&field, direction);
+        }
+
+        query_builder = query_builder.limit(fetch_limit);
+
+        let bind_values = query_builder.bind_values().to_vec();
+        let query = query_builder.build();
+
+        let mut rows = caching_pool.fetch_all_cached(&query, &bind_values).await?.into_cached_rows();
+        let has_more = rows.len() >= fetch_limit as usize;
+
+        if pagination.direction == PaginationDirection::Backward {
+            rows.reverse();
+        }
+
+        let mut next_cursor = None;
+
+        if has_more {
+            rows.truncate(original_limit as usize);
+            if let Some(last_row) = rows.last() {
+                let cursor_values = build_cursor_values_cached(&pagination, last_row)?;
+                next_cursor = Some(encode_cursor_values(&cursor_values)?);
+            }
+        }
+
+        Ok(Page {
+            items: rows,
+            next_cursor,
+        })
+    }
 }
diff --git a/crates/sqlite/sqlite/src/storage.rs b/crates/sqlite/sqlite/src/storage.rs
index ecf4bf84..8aef6c8e 100644
--- a/crates/sqlite/sqlite/src/storage.rs
+++ b/crates/sqlite/sqlite/src/storage.rs
@@ -10,6 +10,7 @@ use dojo_world::{config::WorldMetadata, contracts::abigen::model::Layout};
 use sqlx::{sqlite::SqliteRow, FromRow, Row};
 use starknet::core::types::U256;
 use starknet_crypto::{poseidon_hash_many, Felt};
+use torii_cache::query_cache::{CachedRow, CachedValue};
 use torii_math::I256;
 use torii_proto::{
     schema::Entity, Activity, ActivityQuery, AggregationEntry, AggregationQuery, BalanceId,
diff --git a/crates/sqlite/sqlite/src/storage.rs b/crates/sqlite/sqlite/src/storage.rs
index ecf4bf84..8aef6c8e 100644
--- a/crates/sqlite/sqlite/src/storage.rs
+++ b/crates/sqlite/sqlite/src/storage.rs
@@ -10,6 +10,7 @@ use dojo_world::{config::WorldMetadata, contracts::abigen::model::Layout};
 use sqlx::{sqlite::SqliteRow, FromRow, Row};
 use starknet::core::types::U256;
 use starknet_crypto::{poseidon_hash_many, Felt};
+use torii_cache::query_cache::{CachedRow, CachedValue};
 use torii_math::I256;
 use torii_proto::{
     schema::Entity, Activity, ActivityQuery, AggregationEntry, AggregationQuery, BalanceId,
@@ -40,12 +41,139 @@ use crate::{
         error::ExecutorQueryError, ApplyBalanceDiffQuery, Argument, DeleteEntityQuery, EntityQuery,
         EventMessageQuery, QueryMessage, QueryType, StoreTransactionQuery, UpdateCursorsQuery,
     },
+    error::QueryError,
     utils::{felt_to_sql_string, felts_to_sql_string, utc_dt_string_from_timestamp},
     Sql,
 };
 
 pub const LOG_TARGET: &str = "torii::sqlite::storage";
 
+fn cached_value_to_string(value: &CachedValue, column: &str) -> Result<String, Error> {
+    match value {
+        CachedValue::Text(text) => Ok(text.clone()),
+        CachedValue::Integer(int_val) => Ok(int_val.to_string()),
+        CachedValue::Real(real_val) => Ok(real_val.to_string()),
+        CachedValue::Blob(_) | CachedValue::Null => Err(Error::Query(QueryError::InvalidCursor(
+            format!("Unsupported value for column {}", column),
+        ))),
+    }
+}
+
+fn cached_row_string(row: &CachedRow, column: &str) -> Result<String, Error> {
+    let value = row.get(column).ok_or_else(|| {
+        Error::Query(QueryError::InvalidCursor(format!(
+            "Missing column {}",
+            column
+        )))
+    })?;
+    cached_value_to_string(value, column)
+}
+
+fn cached_row_optional_string(row: &CachedRow, column: &str) -> Result<Option<String>, Error> {
+    let value = row.get(column).ok_or_else(|| {
+        Error::Query(QueryError::InvalidCursor(format!(
+            "Missing column {}",
+            column
+        )))
+    })?;
+
+    match value {
+        CachedValue::Null => Ok(None),
+        _ => cached_value_to_string(value, column).map(Some),
+    }
+}
+
+fn cached_row_u8(row: &CachedRow, column: &str) -> Result<u8, Error> {
+    let value = row.get(column).ok_or_else(|| {
+        Error::Query(QueryError::InvalidCursor(format!(
+            "Missing column {}",
+            column
+        )))
+    })?;
+
+    match value {
+        CachedValue::Integer(int_val) => Ok(*int_val as u8),
+        CachedValue::Text(text) => text.parse::<u8>().map_err(|_| {
+            Error::Query(QueryError::InvalidCursor(format!(
+                "Invalid u8 value for column {}",
+                column
+            )))
+        }),
+        _ => Err(Error::Query(QueryError::InvalidCursor(format!(
+            "Unsupported u8 value for column {}",
+            column
+        )))),
+    }
+}
+
+fn cached_row_datetime(row: &CachedRow, column: &str) -> Result<DateTime<Utc>, Error> {
+    let value = row.get(column).ok_or_else(|| {
+        Error::Query(QueryError::InvalidCursor(format!(
+            "Missing column {}",
+            column
+        )))
+    })?;
+
+    match value {
+        CachedValue::Text(text) => DateTime::parse_from_rfc3339(text)
+            .map(|dt| dt.with_timezone(&Utc))
+            .map_err(|_| {
+                Error::Query(QueryError::InvalidCursor(format!(
+                    "Invalid datetime for column {}",
+                    column
+                )))
+            }),
+        CachedValue::Integer(int_val) => DateTime::from_timestamp(*int_val, 0).ok_or_else(|| {
+            Error::Query(QueryError::InvalidCursor(format!(
+                "Invalid timestamp for column {}",
+                column
+            )))
+        }),
+        _ => Err(Error::Query(QueryError::InvalidCursor(format!(
+            "Unsupported datetime value for column {}",
+            column
+        )))),
+    }
+}
+
+fn token_from_cached_row(row: &CachedRow) -> Result<torii_sqlite_types::Token, Error> {
+    Ok(torii_sqlite_types::Token {
+        id: cached_row_string(row, "id")?,
+        contract_address: cached_row_string(row, "contract_address")?,
+        token_id: cached_row_string(row, "token_id")?,
+        name: cached_row_string(row, "name")?,
+        symbol: cached_row_string(row, "symbol")?,
+        decimals: cached_row_u8(row, "decimals")?,
+        metadata: cached_row_string(row, "metadata")?,
+        total_supply: cached_row_optional_string(row, "total_supply")?,
+    })
+}
+
+fn token_balance_from_cached_row(row: &CachedRow) -> Result<torii_sqlite_types::TokenBalance, Error> {
+    Ok(torii_sqlite_types::TokenBalance {
+        id: cached_row_string(row, "id")?,
+        balance: cached_row_string(row, "balance")?,
+        account_address: cached_row_string(row, "account_address")?,
+        contract_address: cached_row_string(row, "contract_address")?,
+        token_id: cached_row_string(row, "token_id")?,
+    })
+}
+
+fn token_transfer_from_cached_row(
+    row: &CachedRow,
+) -> Result<torii_sqlite_types::TokenTransfer, Error> {
+    Ok(torii_sqlite_types::TokenTransfer {
+        id: cached_row_string(row, "id")?,
+        contract_address: cached_row_string(row, "contract_address")?,
+        from_address: cached_row_string(row, "from_address")?,
+        to_address: cached_row_string(row, "to_address")?,
+        amount: cached_row_string(row, "amount")?,
+        token_id: cached_row_string(row, "token_id")?,
+        executed_at: cached_row_datetime(row, "executed_at")?,
+        event_id: cached_row_optional_string(row, "event_id")?,
+    })
+}
+
 #[async_trait]
 impl ReadOnlyStorage for Sql {
     fn as_read_only(&self) -> &dyn ReadOnlyStorage {
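
These helpers decode the loosely typed `CachedValue` variants back into the concrete column types the storage layer expects, mapping any missing column or type mismatch to an `InvalidCursor` query error. A small illustration of the row shape they operate on, built by hand against the public `CachedRow` API (the comments describe expected behavior; no such test exists in this PR):

    use torii_cache::query_cache::{CachedColumn, CachedRow, CachedValue};

    fn main() {
        let row = CachedRow {
            columns: vec![
                CachedColumn { name: "decimals".into(), type_name: "INTEGER".into() },
                CachedColumn { name: "total_supply".into(), type_name: "TEXT".into() },
            ],
            values: vec![CachedValue::Integer(18), CachedValue::Null],
        };
        // cached_row_u8 would accept Integer(18) and return Ok(18u8), while
        // cached_row_optional_string maps Null to Ok(None) instead of an error.
        assert_eq!(row.get_i64("decimals"), Some(18));
        assert!(row.get("total_supply").unwrap().is_null());
    }
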
@@ -313,6 +441,30 @@ impl ReadOnlyStorage for Sql {
             query_builder = query_builder.where_clause(&where_conditions.join(" AND ").to_string());
         }
 
+        if self.query_cache.is_some() {
+            let caching_pool = self.caching_pool();
+            let page = executor
+                .execute_paginated_query_cached(
+                    &caching_pool,
+                    query_builder,
+                    &query.pagination,
+                    &OrderBy {
+                        field: "ordering".to_string(),
+                        direction: OrderDirection::Desc,
+                    },
+                )
+                .await?;
+            let items: Vec<Token> = page
+                .items
+                .into_iter()
+                .map(|row| Result::<Token, Error>::Ok(token_from_cached_row(&row)?.into()))
+                .collect::<Result<Vec<_>, _>>()?;
+            return Ok(Page {
+                items,
+                next_cursor: page.next_cursor,
+            });
+        }
+
         let page = executor
             .execute_paginated_query(
                 query_builder,
@@ -373,6 +525,34 @@ impl ReadOnlyStorage for Sql {
             }
         }
 
+        if self.query_cache.is_some() {
+            let caching_pool = self.caching_pool();
+            let page = executor
+                .execute_paginated_query_cached(
+                    &caching_pool,
+                    query_builder,
+                    &query.pagination,
+                    &OrderBy {
+                        field: "id".to_string(),
+                        direction: OrderDirection::Desc,
+                    },
+                )
+                .await?;
+            let items: Vec<TokenBalance> = page
+                .items
+                .into_iter()
+                .map(|row| {
+                    Result::<TokenBalance, Error>::Ok(
+                        token_balance_from_cached_row(&row)?.into(),
+                    )
+                })
+                .collect::<Result<Vec<_>, _>>()?;
+            return Ok(Page {
+                items,
+                next_cursor: page.next_cursor,
+            });
+        }
+
         let page = executor
             .execute_paginated_query(
                 query_builder,
@@ -712,6 +892,36 @@ impl ReadOnlyStorage for Sql {
             }
         }
 
+        if self.query_cache.is_some() {
+            let caching_pool = self.caching_pool();
+            let page = executor
+                .execute_paginated_query_cached(
+                    &caching_pool,
+                    query_builder,
+                    &query.pagination,
+                    &OrderBy {
+                        field: "id".to_string(),
+                        direction: OrderDirection::Desc,
+                    },
+                )
+                .await?;
+
+            let items: Vec<TokenTransfer> = page
+                .items
+                .into_iter()
+                .map(|row| {
+                    Result::<TokenTransfer, Error>::Ok(
+                        token_transfer_from_cached_row(&row)?.into(),
+                    )
+                })
+                .collect::<Result<Vec<_>, _>>()?;
+
+            return Ok(Page {
+                items,
+                next_cursor: page.next_cursor,
+            });
+        }
+
         let page = executor
             .execute_paginated_query(
                 query_builder,
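
Each read path branches on `self.query_cache.is_some()` and returns early with rows decoded from the cache, leaving the existing uncached code path untouched. `CachingPool::fetch_all_cached` itself is not shown in this diff; given the `sha2` dependency added to the workspace and the `torii:query:{table}:*` pattern used by the executor's invalidation, one plausible key scheme hashes the SQL text together with its bind values under a per-table namespace. A hypothetical sketch:

    use sha2::{Digest, Sha256};

    // Hypothetical cache-key derivation; the real CachingPool may differ.
    fn cache_key(table: &str, query: &str, binds: &[String]) -> String {
        let mut hasher = Sha256::new();
        hasher.update(query.as_bytes());
        for bind in binds {
            hasher.update(b"\0"); // separator so ["ab"] and ["a", "b"] hash differently
            hasher.update(bind.as_bytes());
        }
        // Namespacing by table is what lets invalidate_query_cache drop every
        // entry for a table with a single "torii:query:{table}:*" pattern.
        format!("torii:query:{}:{:x}", table, hasher.finalize())
    }
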
diff --git a/crates/sqlite/sqlite/src/utils.rs b/crates/sqlite/sqlite/src/utils.rs
index 1e5d91b9..55a6734d 100644
--- a/crates/sqlite/sqlite/src/utils.rs
+++ b/crates/sqlite/sqlite/src/utils.rs
@@ -6,6 +6,7 @@ use chrono::{DateTime, Utc};
 use sqlx::{Column, Row, TypeInfo};
 use starknet::core::types::U256;
 use starknet_crypto::Felt;
+use torii_cache::query_cache::{CachedRow, CachedValue};
 
 use crate::constants::SQL_FELT_DELIMITER;
 
@@ -217,6 +218,37 @@
     torii_proto::proto::types::SqlRow { fields }
 }
 
+// Map a cached row to proto SqlRow type
+pub fn map_cached_row_to_proto(row: &CachedRow) -> torii_proto::proto::types::SqlRow {
+    use std::collections::HashMap;
+    use torii_proto::proto::types::{sql_value, SqlValue};
+
+    let mut fields = HashMap::new();
+
+    for (column, value) in row.columns.iter().zip(row.values.iter()) {
+        let value = match value {
+            CachedValue::Null => SqlValue {
+                value_type: Some(sql_value::ValueType::Null(true)),
+            },
+            CachedValue::Integer(int_val) => SqlValue {
+                value_type: Some(sql_value::ValueType::Integer(*int_val)),
+            },
+            CachedValue::Real(real_val) => SqlValue {
+                value_type: Some(sql_value::ValueType::Real(*real_val)),
+            },
+            CachedValue::Text(text) => SqlValue {
+                value_type: Some(sql_value::ValueType::Text(text.clone())),
+            },
+            CachedValue::Blob(blob) => SqlValue {
+                value_type: Some(sql_value::ValueType::Blob(blob.clone())),
+            },
+        };
+        fields.insert(column.name.clone(), value);
+    }
+
+    torii_proto::proto::types::SqlRow { fields }
+}
+
 #[cfg(test)]
 mod tests {
     use chrono::{DateTime, NaiveDate, NaiveTime, Utc};
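
`map_cached_row_to_proto` is the cached counterpart of `map_row_to_proto`, so the two must agree variant-for-variant on how SQLite values surface in the proto `SqlRow`. A test along these lines could pin the mapping down (a suggested addition, not part of this diff; assumes the tests module brings the function into scope with `use super::*;`):

    #[test]
    fn cached_text_maps_to_proto_text() {
        use torii_cache::query_cache::{CachedColumn, CachedRow, CachedValue};
        use torii_proto::proto::types::sql_value::ValueType;

        let row = CachedRow {
            columns: vec![CachedColumn { name: "name".into(), type_name: "TEXT".into() }],
            values: vec![CachedValue::Text("Beast".into())],
        };
        let proto = map_cached_row_to_proto(&row);
        match proto.fields["name"].value_type.as_ref() {
            Some(ValueType::Text(text)) => assert_eq!(text.as_str(), "Beast"),
            other => panic!("expected Text variant, got {:?}", other),
        }
    }
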