diff --git a/.claude/settings.local.json b/.claude/settings.local.json index 4cf35026..449c11a2 100644 --- a/.claude/settings.local.json +++ b/.claude/settings.local.json @@ -111,7 +111,8 @@ "Bash(python3 -c \"import sys,json; d=json.load\\(sys.stdin\\); [print\\(p[''''name''''], p[''''version'''']\\) for p in d[''''packages''''] if p[''''name'''']==''''rmcp'''']\")", "Bash(python3 -c \"import json,sys; d=json.load\\(sys.stdin\\); [print\\(p[''''name''''], p[''''version'''']\\) for p in d[''''packages''''] if p[''''name'''']==''''rmcp'''']\")", "Bash(git push:*)", - "Bash(gh pr:*)" + "Bash(gh pr:*)", + "Bash(do head:*)" ], "deny": [] } diff --git a/Cargo.toml b/Cargo.toml index 571d339d..9d85c69e 100644 --- a/Cargo.toml +++ b/Cargo.toml @@ -26,7 +26,7 @@ resolver = "2" [workspace.package] version = "0.17.0" -rust-version = "1.88" +rust-version = "1.89" edition = "2024" license = "MIT OR Apache-2.0" authors = ["PulseEngine Contributors"] diff --git a/conformance-tests/src/main.rs b/conformance-tests/src/main.rs index 7f26bd00..ddaf13e5 100644 --- a/conformance-tests/src/main.rs +++ b/conformance-tests/src/main.rs @@ -163,14 +163,14 @@ fn list_servers() -> Result<()> { let entry = entry?; let path = entry.path(); - if path.extension().and_then(|s| s.to_str()) == Some("json") { - if let Some(name) = path.file_stem().and_then(|s| s.to_str()) { - // Try to load config to get description - if let Ok(config) = ServerConfig::load(&path) { - println!(" {} - {}", name.green(), config.description); - } else { - println!(" {name}"); - } + if path.extension().and_then(|s| s.to_str()) == Some("json") + && let Some(name) = path.file_stem().and_then(|s| s.to_str()) + { + // Try to load config to get description + if let Ok(config) = ServerConfig::load(&path) { + println!(" {} - {}", name.green(), config.description); + } else { + println!(" {name}"); } } } diff --git a/conformance-tests/src/runner.rs b/conformance-tests/src/runner.rs index 8b7af1bb..606e4d4d 100644 --- a/conformance-tests/src/runner.rs +++ b/conformance-tests/src/runner.rs @@ -107,10 +107,10 @@ impl ConformanceRunner { // Copy conformance results if they exist let conformance_results = PathBuf::from("results"); - if conformance_results.exists() { - if let Err(e) = copy_dir_all(&conformance_results, &self.results_dir) { - eprintln!("{} Failed to copy conformance results: {}", "⚠".yellow(), e); - } + if conformance_results.exists() + && let Err(e) = copy_dir_all(&conformance_results, &self.results_dir) + { + eprintln!("{} Failed to copy conformance results: {}", "⚠".yellow(), e); } // Generate summary @@ -172,12 +172,11 @@ impl ConformanceRunner { if failures > 0 { println!("{} Failed checks:", "⚠".yellow()); for check in checks_array { - if check["status"] == "FAILURE" { - if let (Some(name), Some(desc)) = + if check["status"] == "FAILURE" + && let (Some(name), Some(desc)) = (check["name"].as_str(), check["description"].as_str()) - { - println!(" - {name}: {desc}"); - } + { + println!(" - {name}: {desc}"); } } } diff --git a/integration-tests/src/cli_server_integration.rs b/integration-tests/src/cli_server_integration.rs index f5e5240d..f5c61ea8 100644 --- a/integration-tests/src/cli_server_integration.rs +++ b/integration-tests/src/cli_server_integration.rs @@ -1,3 +1,4 @@ +#![allow(dead_code)] //! Integration tests for CLI and server interaction use crate::test_utils::*; diff --git a/integration-tests/src/end_to_end_scenarios.rs b/integration-tests/src/end_to_end_scenarios.rs index 6a121a49..affc71e2 100644 --- a/integration-tests/src/end_to_end_scenarios.rs +++ b/integration-tests/src/end_to_end_scenarios.rs @@ -1,3 +1,4 @@ +#![allow(dead_code)] //! End-to-end integration scenarios that test the complete MCP framework use crate::test_utils::*; diff --git a/integration-tests/src/monitoring_integration.rs b/integration-tests/src/monitoring_integration.rs index 2eba1e5b..4651b8b5 100644 --- a/integration-tests/src/monitoring_integration.rs +++ b/integration-tests/src/monitoring_integration.rs @@ -1,3 +1,4 @@ +#![allow(dead_code)] //! Integration tests for monitoring across multiple components use crate::test_utils::*; diff --git a/integration-tests/src/transport_server_integration.rs b/integration-tests/src/transport_server_integration.rs index 451547b2..fb0641e4 100644 --- a/integration-tests/src/transport_server_integration.rs +++ b/integration-tests/src/transport_server_integration.rs @@ -1,3 +1,4 @@ +#![allow(dead_code)] //! Integration tests for transport and server interaction use crate::test_utils::*; diff --git a/mcp-external-validation/src/config.rs b/mcp-external-validation/src/config.rs index 421b2c87..68e39627 100644 --- a/mcp-external-validation/src/config.rs +++ b/mcp-external-validation/src/config.rs @@ -295,12 +295,12 @@ impl ValidationConfig { }); } - if let Some(ref url) = self.jsonrpc.validator_url { - if !url.starts_with("http") { - return Err(ValidationError::ConfigurationError { - message: "JSON-RPC validator URL must start with http or https".to_string(), - }); - } + if let Some(ref url) = self.jsonrpc.validator_url + && !url.starts_with("http") + { + return Err(ValidationError::ConfigurationError { + message: "JSON-RPC validator URL must start with http or https".to_string(), + }); } // Validate port ranges diff --git a/mcp-external-validation/src/cross_language.rs b/mcp-external-validation/src/cross_language.rs index c6484475..27bc414c 100644 --- a/mcp-external-validation/src/cross_language.rs +++ b/mcp-external-validation/src/cross_language.rs @@ -216,25 +216,25 @@ impl CrossLanguageTester { let python_commands = ["python3", "python"]; for cmd in &python_commands { - if let Ok(output) = Command::new(cmd).arg("--version").output() { - if output.status.success() { - let version = String::from_utf8_lossy(&output.stdout).trim().to_string(); - - // Check if MCP package is available - let sdk_check = Command::new(cmd) - .args(&["-c", "import mcp; print('available')"]) - .output(); - - let sdk_available = sdk_check.map(|o| o.status.success()).unwrap_or(false); - - return Ok(LanguageRuntime { - language: Language::Python, - executable: cmd.to_string(), - version, - sdk_available, - test_scripts_dir: Some(std::env::temp_dir().join("mcp_python_cross_tests")), - }); - } + if let Ok(output) = Command::new(cmd).arg("--version").output() + && output.status.success() + { + let version = String::from_utf8_lossy(&output.stdout).trim().to_string(); + + // Check if MCP package is available + let sdk_check = Command::new(cmd) + .args(&["-c", "import mcp; print('available')"]) + .output(); + + let sdk_available = sdk_check.map(|o| o.status.success()).unwrap_or(false); + + return Ok(LanguageRuntime { + language: Language::Python, + executable: cmd.to_string(), + version, + sdk_available, + test_scripts_dir: Some(std::env::temp_dir().join("mcp_python_cross_tests")), + }); } } @@ -828,17 +828,17 @@ async function testMcpCrossLanguage(serverUrl, protocolVersion) {{ if needs_setup { match language { Language::Python => { - if let Some(runtime) = self.available_languages.get_mut(&language) { - if let Err(e) = Self::setup_python_environment(runtime).await { - warn!("Failed to setup Python environment: {}", e); - } + if let Some(runtime) = self.available_languages.get_mut(&language) + && let Err(e) = Self::setup_python_environment(runtime).await + { + warn!("Failed to setup Python environment: {}", e); } } Language::JavaScript | Language::TypeScript => { - if let Some(runtime) = self.available_languages.get_mut(&language) { - if let Err(e) = Self::setup_javascript_environment(runtime).await { - warn!("Failed to setup JavaScript environment: {}", e); - } + if let Some(runtime) = self.available_languages.get_mut(&language) + && let Err(e) = Self::setup_javascript_environment(runtime).await + { + warn!("Failed to setup JavaScript environment: {}", e); } } _ => { diff --git a/mcp-external-validation/src/ecosystem.rs b/mcp-external-validation/src/ecosystem.rs index a5d41350..524c3854 100644 --- a/mcp-external-validation/src/ecosystem.rs +++ b/mcp-external-validation/src/ecosystem.rs @@ -280,17 +280,17 @@ impl EcosystemTester { /// Detect Cline CLI fn detect_cline(&self) -> ValidationResult { - if let Ok(output) = Command::new("cline").arg("--version").output() { - if output.status.success() { - let version = String::from_utf8_lossy(&output.stdout).trim().to_string(); - return Ok(ComponentInfo { - component: EcosystemComponent::Cline, - available: true, - version: Some(version), - location: None, - metadata: HashMap::new(), - }); - } + if let Ok(output) = Command::new("cline").arg("--version").output() + && output.status.success() + { + let version = String::from_utf8_lossy(&output.stdout).trim().to_string(); + return Ok(ComponentInfo { + component: EcosystemComponent::Cline, + available: true, + version: Some(version), + location: None, + metadata: HashMap::new(), + }); } Err(ValidationError::ConfigurationError { diff --git a/mcp-external-validation/src/inspector.rs b/mcp-external-validation/src/inspector.rs index 7290366b..b9e18c9d 100644 --- a/mcp-external-validation/src/inspector.rs +++ b/mcp-external-validation/src/inspector.rs @@ -419,10 +419,10 @@ impl InspectorClient { fn extract_session_token(&self, stderr: &str) -> Option { // Look for lines like "🔑 Session token: 3a1c267fad21f7150b7d624c..." for line in stderr.lines() { - if line.contains("Session token:") { - if let Some(token_part) = line.split("Session token:").nth(1) { - return Some(token_part.trim().to_string()); - } + if line.contains("Session token:") + && let Some(token_part) = line.split("Session token:").nth(1) + { + return Some(token_part.trim().to_string()); } } None diff --git a/mcp-external-validation/src/jsonrpc.rs b/mcp-external-validation/src/jsonrpc.rs index a377d477..f5626982 100644 --- a/mcp-external-validation/src/jsonrpc.rs +++ b/mcp-external-validation/src/jsonrpc.rs @@ -332,15 +332,15 @@ impl JsonRpcValidator { } // Check for empty string IDs - if let Some(id_str) = id.as_str() { - if id_str.is_empty() { - issues.push(ValidationIssue::new( - IssueSeverity::Warning, - "id_format".to_string(), - format!("Message {}: Empty string ID is not recommended", index), - "jsonrpc-validator".to_string(), - )); - } + if let Some(id_str) = id.as_str() + && id_str.is_empty() + { + issues.push(ValidationIssue::new( + IssueSeverity::Warning, + "id_format".to_string(), + format!("Message {}: Empty string ID is not recommended", index), + "jsonrpc-validator".to_string(), + )); } } } @@ -352,43 +352,43 @@ impl JsonRpcValidator { index: usize, issues: &mut Vec, ) { - if let Some(method) = message.get("method") { - if let Some(method_str) = method.as_str() { - // Check for reserved method names (starting with rpc.) - if method_str.starts_with("rpc.") && !self.is_allowed_rpc_method(method_str) { - issues.push(ValidationIssue::new( - IssueSeverity::Error, - "method_naming".to_string(), - format!( - "Message {}: Method name '{}' is reserved", - index, method_str - ), - "jsonrpc-validator".to_string(), - )); - } + if let Some(method) = message.get("method") + && let Some(method_str) = method.as_str() + { + // Check for reserved method names (starting with rpc.) + if method_str.starts_with("rpc.") && !self.is_allowed_rpc_method(method_str) { + issues.push(ValidationIssue::new( + IssueSeverity::Error, + "method_naming".to_string(), + format!( + "Message {}: Method name '{}' is reserved", + index, method_str + ), + "jsonrpc-validator".to_string(), + )); + } - // Check for method name conventions - if method_str.is_empty() { - issues.push(ValidationIssue::new( - IssueSeverity::Error, - "method_naming".to_string(), - format!("Message {}: Method name cannot be empty", index), - "jsonrpc-validator".to_string(), - )); - } + // Check for method name conventions + if method_str.is_empty() { + issues.push(ValidationIssue::new( + IssueSeverity::Error, + "method_naming".to_string(), + format!("Message {}: Method name cannot be empty", index), + "jsonrpc-validator".to_string(), + )); + } - // Check for non-ASCII characters - if !method_str.is_ascii() { - issues.push(ValidationIssue::new( - IssueSeverity::Warning, - "method_naming".to_string(), - format!( - "Message {}: Method name contains non-ASCII characters", - index - ), - "jsonrpc-validator".to_string(), - )); - } + // Check for non-ASCII characters + if !method_str.is_ascii() { + issues.push(ValidationIssue::new( + IssueSeverity::Warning, + "method_naming".to_string(), + format!( + "Message {}: Method name contains non-ASCII characters", + index + ), + "jsonrpc-validator".to_string(), + )); } } } @@ -434,21 +434,21 @@ impl JsonRpcValidator { /// Check error codes fn check_error_codes(&self, message: &Value, index: usize, issues: &mut Vec) { - if let Some(error) = message.get("error") { - if let Some(error_obj) = error.as_object() { - if let Some(code) = error_obj.get("code") { - if let Some(code_num) = code.as_i64() { - if !self.is_valid_error_code(code_num) { - issues.push(ValidationIssue::new( - IssueSeverity::Warning, - "error_codes".to_string(), - format!("Message {}: Error code {} is not a standard JSON-RPC error code", index, code_num), - "jsonrpc-validator".to_string(), - )); - } - } - } - } + if let Some(error) = message.get("error") + && let Some(error_obj) = error.as_object() + && let Some(code) = error_obj.get("code") + && let Some(code_num) = code.as_i64() + && !self.is_valid_error_code(code_num) + { + issues.push(ValidationIssue::new( + IssueSeverity::Warning, + "error_codes".to_string(), + format!( + "Message {}: Error code {} is not a standard JSON-RPC error code", + index, code_num + ), + "jsonrpc-validator".to_string(), + )); } } @@ -801,10 +801,10 @@ impl JsonRpcValidator { // Parse stdout for JSON-RPC responses let stdout_str = String::from_utf8_lossy(&output.stdout); for line in stdout_str.lines() { - if let Ok(parsed) = serde_json::from_str::(line) { - if self.is_jsonrpc_message(&parsed) { - messages.push(parsed); - } + if let Ok(parsed) = serde_json::from_str::(line) + && self.is_jsonrpc_message(&parsed) + { + messages.push(parsed); } } } @@ -871,10 +871,10 @@ impl JsonRpcValidator { // Send request and collect response match client.post(server_url).json(&request).send().await { Ok(response) => { - if response.status().is_success() { - if let Ok(response_json) = response.json::().await { - messages.push(response_json); - } + if response.status().is_success() + && let Ok(response_json) = response.json::().await + { + messages.push(response_json); } } Err(_) => { diff --git a/mcp-external-validation/src/mcp_semantic.rs b/mcp-external-validation/src/mcp_semantic.rs index 71f0d9f2..03c3bbe1 100644 --- a/mcp-external-validation/src/mcp_semantic.rs +++ b/mcp-external-validation/src/mcp_semantic.rs @@ -303,18 +303,18 @@ impl McpSemanticValidator { self.validate_required_field(params, "clientInfo", index, result)?; // Validate protocol version - if let Some(version) = params.get("protocolVersion").and_then(|v| v.as_str()) { - if !self.is_supported_protocol_version(version) { - result.issues.push(ValidationIssue::new( - IssueSeverity::Warning, - "initialization".to_string(), - format!( - "Message {}: Unsupported protocol version: {}", - index, version - ), - "mcp-semantic".to_string(), - )); - } + if let Some(version) = params.get("protocolVersion").and_then(|v| v.as_str()) + && !self.is_supported_protocol_version(version) + { + result.issues.push(ValidationIssue::new( + IssueSeverity::Warning, + "initialization".to_string(), + format!( + "Message {}: Unsupported protocol version: {}", + index, version + ), + "mcp-semantic".to_string(), + )); } // Store client capabilities @@ -424,15 +424,15 @@ impl McpSemanticValidator { self.validate_required_field(error, "message", index, result)?; // Validate error codes are within MCP specification - if let Some(code) = error.get("code").and_then(|c| c.as_i64()) { - if !self.is_valid_mcp_error_code(code) { - result.issues.push(ValidationIssue::new( - IssueSeverity::Warning, - "error_handling".to_string(), - format!("Message {}: Non-standard MCP error code: {}", index, code), - "mcp-semantic".to_string(), - )); - } + if let Some(code) = error.get("code").and_then(|c| c.as_i64()) + && !self.is_valid_mcp_error_code(code) + { + result.issues.push(ValidationIssue::new( + IssueSeverity::Warning, + "error_handling".to_string(), + format!("Message {}: Non-standard MCP error code: {}", index, code), + "mcp-semantic".to_string(), + )); } } @@ -459,18 +459,18 @@ impl McpSemanticValidator { // Check that no non-initialization messages come before initialize if let Some(init_index) = first_initialize_index { for (i, msg) in self.message_sequence.iter().enumerate() { - if i < init_index && msg.message_type == MessageType::Request { - if let Some(method) = &msg.method { - if method != "initialize" { - result.issues.push(ValidationIssue::new( - IssueSeverity::Error, - "state_transitions".to_string(), - format!("Method '{}' called before initialization", method), - "mcp-semantic".to_string(), - )); - return Ok(()); - } - } + if i < init_index + && msg.message_type == MessageType::Request + && let Some(method) = &msg.method + && method != "initialize" + { + result.issues.push(ValidationIssue::new( + IssueSeverity::Error, + "state_transitions".to_string(), + format!("Method '{}' called before initialization", method), + "mcp-semantic".to_string(), + )); + return Ok(()); } } } @@ -687,18 +687,18 @@ impl McpSemanticValidator { } // Validate URI is provided - if let Some(params) = message.get("params") { - if params.get("uri").is_none() { - result.issues.push(ValidationIssue::new( - IssueSeverity::Error, - "method_compliance".to_string(), - format!( - "Message {}: resources/read missing required 'uri' parameter", - index - ), - "mcp-semantic".to_string(), - )); - } + if let Some(params) = message.get("params") + && params.get("uri").is_none() + { + result.issues.push(ValidationIssue::new( + IssueSeverity::Error, + "method_compliance".to_string(), + format!( + "Message {}: resources/read missing required 'uri' parameter", + index + ), + "mcp-semantic".to_string(), + )); } Ok(()) } @@ -743,18 +743,18 @@ impl McpSemanticValidator { } // Validate name is provided - if let Some(params) = message.get("params") { - if params.get("name").is_none() { - result.issues.push(ValidationIssue::new( - IssueSeverity::Error, - "method_compliance".to_string(), - format!( - "Message {}: prompts/get missing required 'name' parameter", - index - ), - "mcp-semantic".to_string(), - )); - } + if let Some(params) = message.get("params") + && params.get("name").is_none() + { + result.issues.push(ValidationIssue::new( + IssueSeverity::Error, + "method_compliance".to_string(), + format!( + "Message {}: prompts/get missing required 'name' parameter", + index + ), + "mcp-semantic".to_string(), + )); } Ok(()) } @@ -776,17 +776,16 @@ impl McpSemanticValidator { result: &mut McpSemanticResult, ) -> ValidationResult<()> { // Validate log level if provided - if let Some(params) = message.get("params") { - if let Some(level) = params.get("level").and_then(|l| l.as_str()) { - if !matches!(level, "debug" | "info" | "warning" | "error") { - result.issues.push(ValidationIssue::new( - IssueSeverity::Warning, - "method_compliance".to_string(), - format!("Message {}: Invalid logging level '{}'", index, level), - "mcp-semantic".to_string(), - )); - } - } + if let Some(params) = message.get("params") + && let Some(level) = params.get("level").and_then(|l| l.as_str()) + && !matches!(level, "debug" | "info" | "warning" | "error") + { + result.issues.push(ValidationIssue::new( + IssueSeverity::Warning, + "method_compliance".to_string(), + format!("Message {}: Invalid logging level '{}'", index, level), + "mcp-semantic".to_string(), + )); } Ok(()) } @@ -867,12 +866,12 @@ impl McpSemanticValidator { _index: usize, _result: &mut McpSemanticResult, ) -> ValidationResult<()> { - if let Some(response_result) = message.get("result") { - if let Some(tools) = response_result.get("tools").and_then(|t| t.as_array()) { - for tool in tools { - if let Some(name) = tool.get("name").and_then(|n| n.as_str()) { - self.available_tools.insert(name.to_string()); - } + if let Some(response_result) = message.get("result") + && let Some(tools) = response_result.get("tools").and_then(|t| t.as_array()) + { + for tool in tools { + if let Some(name) = tool.get("name").and_then(|n| n.as_str()) { + self.available_tools.insert(name.to_string()); } } } @@ -885,12 +884,12 @@ impl McpSemanticValidator { _index: usize, _result: &mut McpSemanticResult, ) -> ValidationResult<()> { - if let Some(response_result) = message.get("result") { - if let Some(resources) = response_result.get("resources").and_then(|r| r.as_array()) { - for resource in resources { - if let Some(uri) = resource.get("uri").and_then(|u| u.as_str()) { - self.available_resources.insert(uri.to_string()); - } + if let Some(response_result) = message.get("result") + && let Some(resources) = response_result.get("resources").and_then(|r| r.as_array()) + { + for resource in resources { + if let Some(uri) = resource.get("uri").and_then(|u| u.as_str()) { + self.available_resources.insert(uri.to_string()); } } } @@ -903,12 +902,12 @@ impl McpSemanticValidator { _index: usize, _result: &mut McpSemanticResult, ) -> ValidationResult<()> { - if let Some(response_result) = message.get("result") { - if let Some(prompts) = response_result.get("prompts").and_then(|p| p.as_array()) { - for prompt in prompts { - if let Some(name) = prompt.get("name").and_then(|n| n.as_str()) { - self.available_prompts.insert(name.to_string()); - } + if let Some(response_result) = message.get("result") + && let Some(prompts) = response_result.get("prompts").and_then(|p| p.as_array()) + { + for prompt in prompts { + if let Some(name) = prompt.get("name").and_then(|n| n.as_str()) { + self.available_prompts.insert(name.to_string()); } } } @@ -943,18 +942,18 @@ impl McpSemanticValidator { result: &mut McpSemanticResult, ) -> ValidationResult<()> { // Validate request ID is provided - if let Some(params) = message.get("params") { - if params.get("requestId").is_none() { - result.issues.push(ValidationIssue::new( - IssueSeverity::Error, - "method_compliance".to_string(), - format!( - "Message {}: cancelled notification missing 'requestId'", - index - ), - "mcp-semantic".to_string(), - )); - } + if let Some(params) = message.get("params") + && params.get("requestId").is_none() + { + result.issues.push(ValidationIssue::new( + IssueSeverity::Error, + "method_compliance".to_string(), + format!( + "Message {}: cancelled notification missing 'requestId'", + index + ), + "mcp-semantic".to_string(), + )); } Ok(()) } @@ -966,18 +965,18 @@ impl McpSemanticValidator { result: &mut McpSemanticResult, ) -> ValidationResult<()> { // Validate progress token and value - if let Some(params) = message.get("params") { - if params.get("progressToken").is_none() { - result.issues.push(ValidationIssue::new( - IssueSeverity::Error, - "method_compliance".to_string(), - format!( - "Message {}: progress notification missing 'progressToken'", - index - ), - "mcp-semantic".to_string(), - )); - } + if let Some(params) = message.get("params") + && params.get("progressToken").is_none() + { + result.issues.push(ValidationIssue::new( + IssueSeverity::Error, + "method_compliance".to_string(), + format!( + "Message {}: progress notification missing 'progressToken'", + index + ), + "mcp-semantic".to_string(), + )); } Ok(()) } diff --git a/mcp-external-validation/src/python_sdk.rs b/mcp-external-validation/src/python_sdk.rs index 597331d4..54956b6e 100644 --- a/mcp-external-validation/src/python_sdk.rs +++ b/mcp-external-validation/src/python_sdk.rs @@ -210,10 +210,10 @@ impl PythonSdkTester { for cmd in &python_commands { let output = Command::new(cmd).arg("--version").output(); - if let Ok(output) = output { - if output.status.success() { - return Ok(cmd.to_string()); - } + if let Ok(output) = output + && output.status.success() + { + return Ok(cmd.to_string()); } } diff --git a/mcp-external-validation/src/security.rs b/mcp-external-validation/src/security.rs index 9d164c7a..c738c4af 100644 --- a/mcp-external-validation/src/security.rs +++ b/mcp-external-validation/src/security.rs @@ -450,18 +450,17 @@ impl SecurityTester { { Ok(output) if output.status.success() => { // Parse the JSON output to check completeness - if let Ok(completeness_str) = String::from_utf8(output.stdout) { - if let Ok(completeness) = + if let Ok(completeness_str) = String::from_utf8(output.stdout) + && let Ok(completeness) = serde_json::from_str::(&completeness_str) - { - if let Some(production_ready) = completeness - .get("production_ready") - .and_then(|v| v.as_bool()) - { - if production_ready { - // Framework has complete API key management - result.authentication.passed += 1; - result.issues.push(ValidationIssue::new( + && let Some(production_ready) = completeness + .get("production_ready") + .and_then(|v| v.as_bool()) + && production_ready + { + // Framework has complete API key management + result.authentication.passed += 1; + result.issues.push(ValidationIssue::new( IssueSeverity::Info, "framework-auth".to_string(), "✅ Authentication Framework Complete: pulseengine_auth has full API key management capabilities".to_string(), @@ -486,10 +485,7 @@ impl SecurityTester { "Bulk operations" ]) )); - return; - } - } - } + return; } } _ => { diff --git a/mcp-macros/src/mcp_tool.rs b/mcp-macros/src/mcp_tool.rs index 1006017d..3e59fa7f 100644 --- a/mcp-macros/src/mcp_tool.rs +++ b/mcp-macros/src/mcp_tool.rs @@ -449,26 +449,22 @@ fn generate_matchit_resource_impl( /// Extract URI template from mcp_resource attribute fn extract_uri_template_from_attr(attrs: &[syn::Attribute]) -> Option { for attr in attrs { - if attr.path().is_ident("mcp_resource") { - if let Ok(meta_list) = attr.meta.require_list() { - // Parse the meta list tokens properly - if let Ok(nested_meta) = - darling::ast::NestedMeta::parse_meta_list(meta_list.tokens.clone()) - { - for nested in nested_meta { - if let darling::ast::NestedMeta::Meta(syn::Meta::NameValue(name_value)) = - nested - { - if name_value.path.is_ident("uri_template") { - if let syn::Expr::Lit(syn::ExprLit { - lit: syn::Lit::Str(lit_str), - .. - }) = name_value.value - { - return Some(lit_str.value()); - } - } - } + if attr.path().is_ident("mcp_resource") + && let Ok(meta_list) = attr.meta.require_list() + { + // Parse the meta list tokens properly + if let Ok(nested_meta) = + darling::ast::NestedMeta::parse_meta_list(meta_list.tokens.clone()) + { + for nested in nested_meta { + if let darling::ast::NestedMeta::Meta(syn::Meta::NameValue(name_value)) = nested + && name_value.path.is_ident("uri_template") + && let syn::Expr::Lit(syn::ExprLit { + lit: syn::Lit::Str(lit_str), + .. + }) = name_value.value + { + return Some(lit_str.value()); } } } @@ -1038,16 +1034,13 @@ fn generate_multi_parameter_schema(params: &[&syn::PatType]) -> syn::Result and extract the inner type T fn extract_option_inner_type(ty: &syn::Type) -> (bool, &syn::Type) { - if let syn::Type::Path(type_path) = ty { - if let Some(segment) = type_path.path.segments.last() { - if segment.ident == "Option" { - if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { - if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() { - return (true, inner_ty); - } - } - } - } + if let syn::Type::Path(type_path) = ty + && let Some(segment) = type_path.path.segments.last() + && segment.ident == "Option" + && let syn::PathArguments::AngleBracketed(args) = &segment.arguments + && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() + { + return (true, inner_ty); } (false, ty) } @@ -1082,26 +1075,24 @@ fn is_tool_context_type(ty: &syn::Type) -> bool { syn::Type::Path(type_path) => { if let Some(segment) = type_path.path.segments.last() { // Check if it's a wrapper like Arc, Box, Rc containing dyn ToolContext - if matches!(segment.ident.to_string().as_str(), "Arc" | "Box" | "Rc") { - if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { - for arg in &args.args { - // Check for dyn ToolContext (collapsed pattern) - if let syn::GenericArgument::Type(syn::Type::TraitObject(trait_obj)) = - arg - { - return trait_obj.bounds.iter().any(|bound| { - if let syn::TypeParamBound::Trait(trait_bound) = bound { - trait_bound - .path - .segments - .last() - .map(|seg| seg.ident == "ToolContext") - .unwrap_or(false) - } else { - false - } - }); - } + if matches!(segment.ident.to_string().as_str(), "Arc" | "Box" | "Rc") + && let syn::PathArguments::AngleBracketed(args) = &segment.arguments + { + for arg in &args.args { + // Check for dyn ToolContext (collapsed pattern) + if let syn::GenericArgument::Type(syn::Type::TraitObject(trait_obj)) = arg { + return trait_obj.bounds.iter().any(|bound| { + if let syn::TypeParamBound::Trait(trait_bound) = bound { + trait_bound + .path + .segments + .last() + .map(|seg| seg.ident == "ToolContext") + .unwrap_or(false) + } else { + false + } + }); } } } @@ -1185,11 +1176,11 @@ fn generate_type_schema_for_type(ty: &syn::Type) -> TokenStream { "bool" => quote! { ::serde_json::json!({"type": "boolean"}) }, "Vec" => { // Handle Vec - extract T and create array schema - if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { - if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() { - let items_schema = generate_type_schema_for_type(inner_ty); - return quote! { ::serde_json::json!({"type": "array", "items": #items_schema}) }; - } + if let syn::PathArguments::AngleBracketed(args) = &segment.arguments + && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() + { + let items_schema = generate_type_schema_for_type(inner_ty); + return quote! { ::serde_json::json!({"type": "array", "items": #items_schema}) }; } quote! { ::serde_json::json!({"type": "array"}) } } diff --git a/mcp-macros/src/utils.rs b/mcp-macros/src/utils.rs index 558bc36d..4c96c8ef 100644 --- a/mcp-macros/src/utils.rs +++ b/mcp-macros/src/utils.rs @@ -67,16 +67,14 @@ pub fn extract_doc_comment(attrs: &[Attribute]) -> Option { let mut docs = Vec::new(); for attr in attrs { - if let Meta::NameValue(meta) = &attr.meta { - if meta.path.is_ident("doc") { - if let Expr::Lit(expr_lit) = &meta.value { - if let Lit::Str(lit_str) = &expr_lit.lit { - let content = lit_str.value().trim().to_string(); - if !content.is_empty() { - docs.push(content); - } - } - } + if let Meta::NameValue(meta) = &attr.meta + && meta.path.is_ident("doc") + && let Expr::Lit(expr_lit) = &meta.value + && let Lit::Str(lit_str) = &expr_lit.lit + { + let content = lit_str.value().trim().to_string(); + if !content.is_empty() { + docs.push(content); } } } @@ -114,10 +112,10 @@ pub fn generate_tool_id(base_name: &str) -> syn::Ident { /// Check if a type is an Option pub fn is_option_type(ty: &syn::Type) -> bool { - if let syn::Type::Path(type_path) = ty { - if let Some(segment) = type_path.path.segments.last() { - return segment.ident == "Option"; - } + if let syn::Type::Path(type_path) = ty + && let Some(segment) = type_path.path.segments.last() + { + return segment.ident == "Option"; } false } @@ -125,16 +123,13 @@ pub fn is_option_type(ty: &syn::Type) -> bool { /// Extract the inner type from Option #[allow(dead_code)] pub fn extract_option_inner_type(ty: &syn::Type) -> Option<&syn::Type> { - if let syn::Type::Path(type_path) = ty { - if let Some(segment) = type_path.path.segments.last() { - if segment.ident == "Option" { - if let syn::PathArguments::AngleBracketed(args) = &segment.arguments { - if let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() { - return Some(inner_ty); - } - } - } - } + if let syn::Type::Path(type_path) = ty + && let Some(segment) = type_path.path.segments.last() + && segment.ident == "Option" + && let syn::PathArguments::AngleBracketed(args) = &segment.arguments + && let Some(syn::GenericArgument::Type(inner_ty)) = args.args.first() + { + return Some(inner_ty); } None } @@ -154,40 +149,39 @@ pub fn generate_error_handling(return_type: &syn::ReturnType) -> TokenStream { } syn::ReturnType::Type(_, ty) => { // Check if it's a Result type - if let syn::Type::Path(type_path) = &**ty { - if let Some(segment) = type_path.path.segments.last() { - if segment.ident == "Result" { - // It's already a Result, wrap it properly for the dispatch context - // Use JSON serialization for structured data, with fallback - return quote! { - match result { - Ok(value) => { - // Try JSON serialization first (for structured data) - let (text_content, structured) = match serde_json::to_value(&value) { - Ok(json_value) => { - // Serialize as JSON string for text content - let text = serde_json::to_string(&value) - .unwrap_or_else(|_| format!("{:?}", value)); - (text, Some(json_value)) - } - Err(_) => { - // Fallback to Debug if not serializable - (format!("{:?}", value), None) - } - }; - - Ok(pulseengine_mcp_protocol::CallToolResult { - content: vec![pulseengine_mcp_protocol::Content::text(text_content)], - is_error: Some(false), - structured_content: structured, - _meta: None, - }) + if let syn::Type::Path(type_path) = &**ty + && let Some(segment) = type_path.path.segments.last() + && segment.ident == "Result" + { + // It's already a Result, wrap it properly for the dispatch context + // Use JSON serialization for structured data, with fallback + return quote! { + match result { + Ok(value) => { + // Try JSON serialization first (for structured data) + let (text_content, structured) = match serde_json::to_value(&value) { + Ok(json_value) => { + // Serialize as JSON string for text content + let text = serde_json::to_string(&value) + .unwrap_or_else(|_| format!("{:?}", value)); + (text, Some(json_value)) } - Err(e) => Err(pulseengine_mcp_protocol::Error::internal_error(e.to_string())), - } - }; + Err(_) => { + // Fallback to Debug if not serializable + (format!("{:?}", value), None) + } + }; + + Ok(pulseengine_mcp_protocol::CallToolResult { + content: vec![pulseengine_mcp_protocol::Content::text(text_content)], + is_error: Some(false), + structured_content: structured, + _meta: None, + }) + } + Err(e) => Err(pulseengine_mcp_protocol::Error::internal_error(e.to_string())), } - } + }; } // Not a Result, wrap it with JSON serialization (preferred) or Display formatting diff --git a/mcp-macros/tests/advanced_features.rs b/mcp-macros/tests/advanced_features.rs index 3105318e..3b85ca73 100644 --- a/mcp-macros/tests/advanced_features.rs +++ b/mcp-macros/tests/advanced_features.rs @@ -98,13 +98,13 @@ async fn test_complex_parameters() { let result = server.call_tool(request).await; assert!(result.is_ok()); - if let Ok(result) = result { - if let Some(pulseengine_mcp_protocol::Content::Text { text, .. }) = result.content.first() { - assert!(text.contains("Text: Hello")); - assert!(text.contains("Number: 42")); - assert!(text.contains("Flag: true")); - assert!(text.contains("List length: 3")); - } + if let Ok(result) = result + && let Some(pulseengine_mcp_protocol::Content::Text { text, .. }) = result.content.first() + { + assert!(text.contains("Text: Hello")); + assert!(text.contains("Number: 42")); + assert!(text.contains("Flag: true")); + assert!(text.contains("List length: 3")); } } diff --git a/mcp-macros/tests/dual_pattern_test.rs b/mcp-macros/tests/dual_pattern_test.rs index 67526779..0ed4ec4a 100644 --- a/mcp-macros/tests/dual_pattern_test.rs +++ b/mcp-macros/tests/dual_pattern_test.rs @@ -250,15 +250,14 @@ async fn test_multi_param_still_works() { result.err() ); - if let Ok(call_result) = result { - if let Some(pulseengine_mcp_protocol::Content::Text { text, .. }) = + if let Ok(call_result) = result + && let Some(pulseengine_mcp_protocol::Content::Text { text, .. }) = call_result.content.first() - { - assert!(text.contains("Alice")); - assert!(text.contains("30")); - assert!(text.contains("true")); - println!("✅ Multi-parameter tool works correctly!"); - } + { + assert!(text.contains("Alice")); + assert!(text.contains("30")); + assert!(text.contains("true")); + println!("✅ Multi-parameter tool works correctly!"); } } diff --git a/mcp-protocol/src/validation.rs b/mcp-protocol/src/validation.rs index b38618f1..a497b80d 100644 --- a/mcp-protocol/src/validation.rs +++ b/mcp-protocol/src/validation.rs @@ -133,18 +133,17 @@ impl Validator { /// Returns an error if required arguments are missing from the provided arguments pub fn validate_tool_arguments(args: &HashMap, schema: &Value) -> Result<()> { // Basic validation - check required properties if defined - if let Some(schema_obj) = schema.as_object() { - if let Some(_properties) = schema_obj.get("properties").and_then(|p| p.as_object()) { - if let Some(required) = schema_obj.get("required").and_then(|r| r.as_array()) { - for req_field in required { - if let Some(field_name) = req_field.as_str() { - if !args.contains_key(field_name) { - return Err(Error::validation_error(format!( - "Required argument '{field_name}' is missing" - ))); - } - } - } + if let Some(schema_obj) = schema.as_object() + && let Some(_properties) = schema_obj.get("properties").and_then(|p| p.as_object()) + && let Some(required) = schema_obj.get("required").and_then(|r| r.as_array()) + { + for req_field in required { + if let Some(field_name) = req_field.as_str() + && !args.contains_key(field_name) + { + return Err(Error::validation_error(format!( + "Required argument '{field_name}' is missing" + ))); } } } diff --git a/mcp-transport/src/http.rs b/mcp-transport/src/http.rs index 12fcba39..6a50609c 100644 --- a/mcp-transport/src/http.rs +++ b/mcp-transport/src/http.rs @@ -358,12 +358,11 @@ async fn handle_post( // Validate message let message_json = serde_json::to_string(&message).map_err(|_| StatusCode::BAD_REQUEST)?; - if state.config.validate_messages { - if let Err(e) = validate_message_string(&message_json, Some(state.config.max_message_size)) - { - warn!("Message validation failed: {}", e); - return Err(StatusCode::BAD_REQUEST); - } + if state.config.validate_messages + && let Err(e) = validate_message_string(&message_json, Some(state.config.max_message_size)) + { + warn!("Message validation failed: {}", e); + return Err(StatusCode::BAD_REQUEST); } // Parse and process message diff --git a/mcp-transport/src/stdio.rs b/mcp-transport/src/stdio.rs index 4f5a757a..97f936df 100644 --- a/mcp-transport/src/stdio.rs +++ b/mcp-transport/src/stdio.rs @@ -85,22 +85,22 @@ impl StdioTransport { stdout: &mut tokio::io::Stdout, ) -> Result<(), TransportError> { // Validate message according to MCP spec - if self.config.validate_messages { - if let Err(e) = validate_message_string(line, Some(self.config.max_message_size)) { - warn!("Message validation failed: {}", e); + if self.config.validate_messages + && let Err(e) = validate_message_string(line, Some(self.config.max_message_size)) + { + warn!("Message validation failed: {}", e); - // Try to extract ID for error response - let request_id = extract_id_from_malformed(line); - let error_response = create_error_response( - pulseengine_mcp_protocol::Error::invalid_request(format!( - "Message validation failed: {e}" - )), - request_id, - ); + // Try to extract ID for error response + let request_id = extract_id_from_malformed(line); + let error_response = create_error_response( + pulseengine_mcp_protocol::Error::invalid_request(format!( + "Message validation failed: {e}" + )), + request_id, + ); - self.send_response(stdout, &error_response).await?; - return Ok(()); - } + self.send_response(stdout, &error_response).await?; + return Ok(()); } debug!("Processing message: {}", line); @@ -188,12 +188,12 @@ impl StdioTransport { line: &str, ) -> Result<(), TransportError> { // Validate outgoing message - if self.config.validate_messages { - if let Err(e) = validate_message_string(line, Some(self.config.max_message_size)) { - return Err(TransportError::Protocol(format!( - "Outgoing message validation failed: {e}" - ))); - } + if self.config.validate_messages + && let Err(e) = validate_message_string(line, Some(self.config.max_message_size)) + { + return Err(TransportError::Protocol(format!( + "Outgoing message validation failed: {e}" + ))); } debug!("Sending response: {}", line); diff --git a/mcp-transport/src/validation.rs b/mcp-transport/src/validation.rs index a048f7ee..fac8ade2 100644 --- a/mcp-transport/src/validation.rs +++ b/mcp-transport/src/validation.rs @@ -35,13 +35,13 @@ pub fn validate_message_string( } // Check message size limit - if let Some(max) = max_size { - if message.len() > max { - return Err(ValidationError::MessageTooLarge { - size: message.len(), - max, - }); - } + if let Some(max) = max_size + && message.len() > max + { + return Err(ValidationError::MessageTooLarge { + size: message.len(), + max, + }); } // UTF-8 validation is implicit in Rust strings, but we validate the bytes @@ -116,12 +116,11 @@ pub fn extract_id_from_malformed(text: &str) -> Option(text) { - if let Some(obj) = value.as_object() { - if let Some(id) = obj.get("id") { - return NumberOrString::from_json_value(id.clone()); - } - } + if let Ok(value) = serde_json::from_str::(text) + && let Some(obj) = value.as_object() + && let Some(id) = obj.get("id") + { + return NumberOrString::from_json_value(id.clone()); } // Try regex-based extraction as fallback @@ -169,25 +168,24 @@ fn extract_id_with_regex(text: &str) -> Option { ]; for pattern in &patterns { - if let Ok(re) = Regex::new(pattern) { - if let Some(captures) = re.captures(text) { - if let Some(id_str) = captures.get(1) { - let id_text = id_str.as_str(); - - // Try to parse as number first - if let Ok(num) = id_text.parse::() { - return Some(Value::Number(num.into())); - } - - // Check for null - if id_text == "null" { - return Some(Value::Null); - } - - // Default to string - return Some(Value::String(id_text.to_string())); - } + if let Ok(re) = Regex::new(pattern) + && let Some(captures) = re.captures(text) + && let Some(id_str) = captures.get(1) + { + let id_text = id_str.as_str(); + + // Try to parse as number first + if let Ok(num) = id_text.parse::() { + return Some(Value::Number(num.into())); } + + // Check for null + if id_text == "null" { + return Some(Value::Null); + } + + // Default to string + return Some(Value::String(id_text.to_string())); } } diff --git a/mcp-transport/src/validation_tests.rs b/mcp-transport/src/validation_tests.rs index 2fc9bcba..50e42567 100644 --- a/mcp-transport/src/validation_tests.rs +++ b/mcp-transport/src/validation_tests.rs @@ -86,8 +86,7 @@ mod tests { // Note: Rust's String type actually ensures valid UTF-8, // so this test may pass. In practice, invalid UTF-8 would // come from external sources (network, files, etc.) - if result.is_err() { - let error = result.unwrap_err(); + if let Err(error) = result { assert!(error.to_string().contains("not valid UTF-8")); } } diff --git a/pulseengine-auth/src/audit.rs b/pulseengine-auth/src/audit.rs index 23d90c21..9fc7d262 100644 --- a/pulseengine-auth/src/audit.rs +++ b/pulseengine-auth/src/audit.rs @@ -394,15 +394,15 @@ impl AuditLogger { .log_file .with_extension(format!("log.{}", i + 1)); - if old_file.exists() { - if let Err(e) = fs::rename(&old_file, &new_file).await { - warn!( - "Failed to rotate log file {} to {}: {}", - old_file.display(), - new_file.display(), - e - ); - } + if old_file.exists() + && let Err(e) = fs::rename(&old_file, &new_file).await + { + warn!( + "Failed to rotate log file {} to {}: {}", + old_file.display(), + new_file.display(), + e + ); } } @@ -418,14 +418,14 @@ impl AuditLogger { .config .log_file .with_extension(format!("log.{}", self.config.max_files)); - if oldest_file.exists() { - if let Err(e) = fs::remove_file(&oldest_file).await { - warn!( - "Failed to remove oldest log file {}: {}", - oldest_file.display(), - e - ); - } + if oldest_file.exists() + && let Err(e) = fs::remove_file(&oldest_file).await + { + warn!( + "Failed to remove oldest log file {}: {}", + oldest_file.display(), + e + ); } debug!( diff --git a/pulseengine-auth/src/config.rs b/pulseengine-auth/src/config.rs index dedafaa6..a1d5ac8f 100644 --- a/pulseengine-auth/src/config.rs +++ b/pulseengine-auth/src/config.rs @@ -136,10 +136,10 @@ impl AuthConfig { /// Get the default storage path for an application fn get_app_storage_path(app_name: &str) -> PathBuf { // Check for environment variable override first - if let Ok(app_name_override) = std::env::var("PULSEENGINE_MCP_APP_NAME") { - if !app_name_override.trim().is_empty() { - return Self::build_storage_path(&app_name_override); - } + if let Ok(app_name_override) = std::env::var("PULSEENGINE_MCP_APP_NAME") + && !app_name_override.trim().is_empty() + { + return Self::build_storage_path(&app_name_override); } Self::build_storage_path(app_name) diff --git a/pulseengine-auth/src/consent/manager.rs b/pulseengine-auth/src/consent/manager.rs index 8689fe8f..f25ca267 100644 --- a/pulseengine-auth/src/consent/manager.rs +++ b/pulseengine-auth/src/consent/manager.rs @@ -445,23 +445,22 @@ impl ConsentManager { let subject_prefix = format!("consent:{subject_id}:"); for key in all_keys { - if key.starts_with(&subject_prefix) { - if let Ok(consent_data) = self.storage.get(&key).await { - if let Ok(record) = serde_json::from_str::(&consent_data) { - consents.insert(record.consent_type.clone(), record.status.clone()); - - if record.status == ConsentStatus::Pending { - pending_requests += 1; - } - - if record.is_expired() { - expired_consents += 1; - } - - if record.updated_at > last_updated { - last_updated = record.updated_at; - } - } + if key.starts_with(&subject_prefix) + && let Ok(consent_data) = self.storage.get(&key).await + && let Ok(record) = serde_json::from_str::(&consent_data) + { + consents.insert(record.consent_type.clone(), record.status.clone()); + + if record.status == ConsentStatus::Pending { + pending_requests += 1; + } + + if record.is_expired() { + expired_consents += 1; + } + + if record.updated_at > last_updated { + last_updated = record.updated_at; } } } @@ -493,26 +492,25 @@ impl ConsentManager { .map_err(|e| ConsentError::StorageError(e.to_string()))?; for key in all_keys { - if key.starts_with("consent:") { - if let Ok(consent_data) = self.storage.get(&key).await { - if let Ok(record) = serde_json::from_str::(&consent_data) { - if record.is_expired() && record.updated_at < cutoff_date { - self.storage - .delete(&key) - .await - .map_err(|e| ConsentError::StorageError(e.to_string()))?; - - // Remove from cache - { - let mut cache = self.consent_cache.write().await; - cache.remove(&record.id); - } - - cleaned_count += 1; - debug!("Cleaned up expired consent record: {}", record.id); - } - } + if key.starts_with("consent:") + && let Ok(consent_data) = self.storage.get(&key).await + && let Ok(record) = serde_json::from_str::(&consent_data) + && record.is_expired() + && record.updated_at < cutoff_date + { + self.storage + .delete(&key) + .await + .map_err(|e| ConsentError::StorageError(e.to_string()))?; + + // Remove from cache + { + let mut cache = self.consent_cache.write().await; + cache.remove(&record.id); } + + cleaned_count += 1; + debug!("Cleaned up expired consent record: {}", record.id); } } diff --git a/pulseengine-auth/src/manager.rs b/pulseengine-auth/src/manager.rs index 4db3bc99..79c6b4ce 100644 --- a/pulseengine-auth/src/manager.rs +++ b/pulseengine-auth/src/manager.rs @@ -424,25 +424,25 @@ impl AuthenticationManager { } // Check role-based rate limiting - if let Ok(is_rate_limited) = self.check_role_rate_limit(&key.role, client_ip).await { - if is_rate_limited { - self.record_failed_attempt(client_ip).await; - - // Log role-based rate limiting - let audit_event = events::auth_failure( - client_ip, - &format!( - "Role-based rate limit exceeded for role {}", - self.get_role_key(&key.role) - ), - ); - let _ = self.audit_logger.log(audit_event).await; + if let Ok(is_rate_limited) = self.check_role_rate_limit(&key.role, client_ip).await + && is_rate_limited + { + self.record_failed_attempt(client_ip).await; - return Err(AuthError::Failed(format!( - "Rate limit exceeded for role {}", + // Log role-based rate limiting + let audit_event = events::auth_failure( + client_ip, + &format!( + "Role-based rate limit exceeded for role {}", self.get_role_key(&key.role) - ))); - } + ), + ); + let _ = self.audit_logger.log(audit_event).await; + + return Err(AuthError::Failed(format!( + "Rate limit exceeded for role {}", + self.get_role_key(&key.role) + ))); } // Clear any failed attempts for this IP @@ -542,12 +542,11 @@ impl AuthenticationManager { async fn check_rate_limit(&self, client_ip: &str) -> Option> { let rate_limits = self.rate_limit_state.read().await; - if let Some(state) = rate_limits.get(client_ip) { - if let Some(blocked_until) = state.blocked_until { - if Utc::now() < blocked_until { - return Some(blocked_until); - } - } + if let Some(state) = rate_limits.get(client_ip) + && let Some(blocked_until) = state.blocked_until + && Utc::now() < blocked_until + { + return Some(blocked_until); } None @@ -615,10 +614,10 @@ impl AuthenticationManager { } // Check if key has expired - if let Some(expires_at) = key.expires_at { - if Utc::now() > expires_at { - return Err("API key has expired".to_string()); - } + if let Some(expires_at) = key.expires_at + && Utc::now() > expires_at + { + return Err("API key has expired".to_string()); } // Check IP whitelist @@ -680,10 +679,10 @@ impl AuthenticationManager { for state in rate_limits.values() { stats.total_failed_attempts += state.failed_attempts as u64; - if let Some(blocked_until) = state.blocked_until { - if now < blocked_until { - stats.currently_blocked_ips += 1; - } + if let Some(blocked_until) = state.blocked_until + && now < blocked_until + { + stats.currently_blocked_ips += 1; } } @@ -704,14 +703,14 @@ impl AuthenticationManager { role_statistics.total_requests += state.total_requests; // Check if any IP is in cooldown for this role - if let Some(cooldown_end) = state.cooldown_ends_at { - if now < cooldown_end { - role_statistics.in_cooldown = true; - if role_statistics.cooldown_ends_at.is_none() - || cooldown_end > role_statistics.cooldown_ends_at.unwrap() - { - role_statistics.cooldown_ends_at = Some(cooldown_end); - } + if let Some(cooldown_end) = state.cooldown_ends_at + && now < cooldown_end + { + role_statistics.in_cooldown = true; + if role_statistics.cooldown_ends_at.is_none() + || cooldown_end > role_statistics.cooldown_ends_at.unwrap() + { + role_statistics.cooldown_ends_at = Some(cooldown_end); } } } @@ -731,10 +730,10 @@ impl AuthenticationManager { let initial_count = rate_limits.len(); rate_limits.retain(|_ip, state| { // Keep if blocked and still in block period - if let Some(blocked_until) = state.blocked_until { - if now < blocked_until { - return true; - } + if let Some(blocked_until) = state.blocked_until + && now < blocked_until + { + return true; } // Keep if within the tracking window @@ -936,17 +935,17 @@ impl AuthenticationManager { let initial_count = ip_states.len(); ip_states.retain(|_ip, state| { // Keep if in cooldown - if let Some(cooldown_end) = state.cooldown_ends_at { - if now < cooldown_end { - return true; - } + if let Some(cooldown_end) = state.cooldown_ends_at + && now < cooldown_end + { + return true; } // Keep if window started recently - if let Some(window_start) = state.last_window_start { - if now.signed_duration_since(window_start) < cleanup_threshold { - return true; - } + if let Some(window_start) = state.last_window_start + && now.signed_duration_since(window_start) < cleanup_threshold + { + return true; } // Remove old inactive entries diff --git a/pulseengine-auth/src/manager_vault.rs b/pulseengine-auth/src/manager_vault.rs index 82588e61..51957e4e 100644 --- a/pulseengine-auth/src/manager_vault.rs +++ b/pulseengine-auth/src/manager_vault.rs @@ -90,10 +90,10 @@ impl VaultAuthenticationManager { } // Try to get additional configuration from vault - if let Some(vault) = &vault_integration { - if let Ok(vault_config) = vault.get_api_config().await { - Self::apply_vault_config(&mut auth_config, &vault_config); - } + if let Some(vault) = &vault_integration + && let Ok(vault_config) = vault.get_api_config().await + { + Self::apply_vault_config(&mut auth_config, &vault_config); } // Use provided validation config or try to create from vault config @@ -129,31 +129,31 @@ impl VaultAuthenticationManager { /// Apply vault configuration to auth config fn apply_vault_config(auth_config: &mut AuthConfig, vault_config: &HashMap) { - if let Some(timeout) = vault_config.get("PULSEENGINE_MCP_SESSION_TIMEOUT") { - if let Ok(timeout_secs) = timeout.parse::() { - auth_config.session_timeout_secs = timeout_secs; - debug!( - "Applied vault config: session_timeout_secs = {}", - timeout_secs - ); - } + if let Some(timeout) = vault_config.get("PULSEENGINE_MCP_SESSION_TIMEOUT") + && let Ok(timeout_secs) = timeout.parse::() + { + auth_config.session_timeout_secs = timeout_secs; + debug!( + "Applied vault config: session_timeout_secs = {}", + timeout_secs + ); } - if let Some(max_attempts) = vault_config.get("PULSEENGINE_MCP_MAX_FAILED_ATTEMPTS") { - if let Ok(attempts) = max_attempts.parse::() { - auth_config.max_failed_attempts = attempts; - debug!("Applied vault config: max_failed_attempts = {}", attempts); - } + if let Some(max_attempts) = vault_config.get("PULSEENGINE_MCP_MAX_FAILED_ATTEMPTS") + && let Ok(attempts) = max_attempts.parse::() + { + auth_config.max_failed_attempts = attempts; + debug!("Applied vault config: max_failed_attempts = {}", attempts); } - if let Some(rate_limit) = vault_config.get("PULSEENGINE_MCP_RATE_LIMIT_WINDOW") { - if let Ok(window_secs) = rate_limit.parse::() { - auth_config.rate_limit_window_secs = window_secs; - debug!( - "Applied vault config: rate_limit_window_secs = {}", - window_secs - ); - } + if let Some(rate_limit) = vault_config.get("PULSEENGINE_MCP_RATE_LIMIT_WINDOW") + && let Ok(window_secs) = rate_limit.parse::() + { + auth_config.rate_limit_window_secs = window_secs; + debug!( + "Applied vault config: rate_limit_window_secs = {}", + window_secs + ); } if let Some(storage_path) = vault_config.get("PULSEENGINE_MCP_STORAGE_PATH") { diff --git a/pulseengine-auth/src/middleware/mcp_auth.rs b/pulseengine-auth/src/middleware/mcp_auth.rs index dd591ca1..d5f5e530 100644 --- a/pulseengine-auth/src/middleware/mcp_auth.rs +++ b/pulseengine-auth/src/middleware/mcp_auth.rs @@ -200,12 +200,11 @@ impl McpAuthMiddleware { let mut context = McpRequestContext::new(id); // Extract client IP if available - if let Some(headers) = headers { - if let Some(ip_header) = &self.config.client_ip_header { - if let Some(client_ip) = headers.get(ip_header) { - context = context.with_client_ip(client_ip.clone()); - } - } + if let Some(headers) = headers + && let Some(ip_header) = &self.config.client_ip_header + && let Some(client_ip) = headers.get(ip_header) + { + context = context.with_client_ip(client_ip.clone()); } // Check if authentication is required for this method diff --git a/pulseengine-auth/src/middleware/session_middleware.rs b/pulseengine-auth/src/middleware/session_middleware.rs index 9d39178b..581f43cc 100644 --- a/pulseengine-auth/src/middleware/session_middleware.rs +++ b/pulseengine-auth/src/middleware/session_middleware.rs @@ -202,12 +202,11 @@ impl SessionMiddleware { let mut session_context = SessionRequestContext::new(base_context.clone()); // Extract client IP - if let Some(headers) = headers { - if let Some(ip_header) = &self.config.auth_config.client_ip_header { - if let Some(client_ip) = headers.get(ip_header) { - base_context = base_context.with_client_ip(client_ip.clone()); - } - } + if let Some(headers) = headers + && let Some(ip_header) = &self.config.auth_config.client_ip_header + && let Some(client_ip) = headers.get(ip_header) + { + base_context = base_context.with_client_ip(client_ip.clone()); } // Check if this method requires authentication/sessions @@ -277,18 +276,17 @@ impl SessionMiddleware { ) -> Result<(AuthContext, String, Option), SessionMiddlewareError> { if let Some(headers) = headers { // Try JWT authentication first - if self.config.enable_jwt_auth { - if let Ok((auth_context, method)) = self.try_jwt_authentication(headers).await { - return Ok((auth_context, method, None)); - } + if self.config.enable_jwt_auth + && let Ok((auth_context, method)) = self.try_jwt_authentication(headers).await + { + return Ok((auth_context, method, None)); } // Try session ID authentication - if self.config.enable_sessions { - if let Ok((auth_context, session)) = self.try_session_authentication(headers).await - { - return Ok((auth_context, "Session".to_string(), Some(session))); - } + if self.config.enable_sessions + && let Ok((auth_context, session)) = self.try_session_authentication(headers).await + { + return Ok((auth_context, "Session".to_string(), Some(session))); } // Fall back to traditional API key authentication @@ -307,12 +305,12 @@ impl SessionMiddleware { &self, headers: &HashMap, ) -> Result<(AuthContext, String), SessionMiddlewareError> { - if let Some(auth_header) = headers.get(&self.config.jwt_header_name) { - if auth_header.starts_with("Bearer ") { - let token = &auth_header[7..]; - let auth_context = self.session_manager.validate_jwt_token(token).await?; - return Ok((auth_context, "JWT".to_string())); - } + if let Some(auth_header) = headers.get(&self.config.jwt_header_name) + && auth_header.starts_with("Bearer ") + { + let token = &auth_header[7..]; + let auth_context = self.session_manager.validate_jwt_token(token).await?; + return Ok((auth_context, "JWT".to_string())); } Err(SessionMiddlewareError::AuthError( @@ -341,17 +339,17 @@ impl SessionMiddleware { headers: &HashMap, ) -> Result<(AuthContext, String), SessionMiddlewareError> { // Try Authorization header - if let Some(auth_header) = headers.get(&self.config.auth_config.auth_header_name) { - if let Ok((auth_context, method)) = self.parse_auth_header(auth_header).await { - return Ok((auth_context, method)); - } + if let Some(auth_header) = headers.get(&self.config.auth_config.auth_header_name) + && let Ok((auth_context, method)) = self.parse_auth_header(auth_header).await + { + return Ok((auth_context, method)); } // Try X-API-Key header - if let Some(api_key) = headers.get("X-API-Key") { - if let Ok(auth_context) = self.validate_api_key(api_key).await { - return Ok((auth_context, "X-API-Key".to_string())); - } + if let Some(api_key) = headers.get("X-API-Key") + && let Ok(auth_context) = self.validate_api_key(api_key).await + { + return Ok((auth_context, "X-API-Key".to_string())); } Err(SessionMiddlewareError::AuthError( diff --git a/pulseengine-auth/src/oauth/bearer.rs b/pulseengine-auth/src/oauth/bearer.rs index 0bb22b65..8447974e 100644 --- a/pulseengine-auth/src/oauth/bearer.rs +++ b/pulseengine-auth/src/oauth/bearer.rs @@ -424,9 +424,9 @@ mod tests { let result = validate_bearer_token(&auth_header, &config); // When no expected audience is configured, validation should succeed // Note: This may fail if jsonwebtoken still requires audience - we test that the config works - if result.is_err() { + if let Err(err) = result { // If it fails, ensure it's not a signature or expiration error - match result.unwrap_err() { + match err { BearerError::ExpiredToken => panic!("Should not be expired"), BearerError::InvalidToken(msg) if msg.contains("signature") => { panic!("Signature should be valid") diff --git a/pulseengine-auth/src/permissions/mcp_permissions.rs b/pulseengine-auth/src/permissions/mcp_permissions.rs index 3f9e42fa..58e21763 100644 --- a/pulseengine-auth/src/permissions/mcp_permissions.rs +++ b/pulseengine-auth/src/permissions/mcp_permissions.rs @@ -337,14 +337,14 @@ impl PermissionChecker { // Check custom rules first for rule in &self.config.custom_rules { - if let Permission::UseTool(rule_tool) = &rule.permission { - if rule_tool == tool_name { - for role in &auth_context.roles { - if rule.applies_to_role(role) { - match rule.action { - PermissionAction::Allow => return true, - PermissionAction::Deny => return false, - } + if let Permission::UseTool(rule_tool) = &rule.permission + && rule_tool == tool_name + { + for role in &auth_context.roles { + if rule.applies_to_role(role) { + match rule.action { + PermissionAction::Allow => return true, + PermissionAction::Deny => return false, } } } @@ -373,13 +373,13 @@ impl PermissionChecker { } // Check tool category permissions - if let Some(category) = self.extract_tool_category(tool_name) { - if let Some(allowed_roles) = self.config.tools.category_permissions.get(&category) { - return auth_context - .roles - .iter() - .any(|role| allowed_roles.contains(role)); - } + if let Some(category) = self.extract_tool_category(tool_name) + && let Some(allowed_roles) = self.config.tools.category_permissions.get(&category) + { + return auth_context + .roles + .iter() + .any(|role| allowed_roles.contains(role)); } // Fall back to default action @@ -398,14 +398,14 @@ impl PermissionChecker { // Check custom rules first for rule in &self.config.custom_rules { - if let Permission::UseResource(rule_resource) = &rule.permission { - if self.matches_resource_pattern(rule_resource, resource_uri) { - for role in &auth_context.roles { - if rule.applies_to_role(role) { - match rule.action { - PermissionAction::Allow => return true, - PermissionAction::Deny => return false, - } + if let Permission::UseResource(rule_resource) = &rule.permission + && self.matches_resource_pattern(rule_resource, resource_uri) + { + for role in &auth_context.roles { + if rule.applies_to_role(role) { + match rule.action { + PermissionAction::Allow => return true, + PermissionAction::Deny => return false, } } } @@ -443,13 +443,13 @@ impl PermissionChecker { } // Check resource category permissions - if let Some(category) = self.extract_resource_category(resource_uri) { - if let Some(allowed_roles) = self.config.resources.category_permissions.get(&category) { - return auth_context - .roles - .iter() - .any(|role| allowed_roles.contains(role)); - } + if let Some(category) = self.extract_resource_category(resource_uri) + && let Some(allowed_roles) = self.config.resources.category_permissions.get(&category) + { + return auth_context + .roles + .iter() + .any(|role| allowed_roles.contains(role)); } // Fall back to default action @@ -474,14 +474,14 @@ impl PermissionChecker { // Check for subscription-specific rules for rule in &self.config.custom_rules { - if let Permission::Subscribe(rule_resource) = &rule.permission { - if self.matches_resource_pattern(rule_resource, resource_uri) { - for role in &auth_context.roles { - if rule.applies_to_role(role) { - match rule.action { - PermissionAction::Allow => return true, - PermissionAction::Deny => return false, - } + if let Permission::Subscribe(rule_resource) = &rule.permission + && self.matches_resource_pattern(rule_resource, resource_uri) + { + for role in &auth_context.roles { + if rule.applies_to_role(role) { + match rule.action { + PermissionAction::Allow => return true, + PermissionAction::Deny => return false, } } } diff --git a/pulseengine-auth/src/security/mod.rs b/pulseengine-auth/src/security/mod.rs index 440318b9..7678d6bb 100644 --- a/pulseengine-auth/src/security/mod.rs +++ b/pulseengine-auth/src/security/mod.rs @@ -195,13 +195,12 @@ mod tests { .await; // If validation passes, sanitize the input - if validation_result.is_ok() { - if let Some(args) = params.get("arguments") { - if let Some(input) = args.get("input").and_then(|v| v.as_str()) { - let sanitized = sanitizer.sanitize_string(input); - assert!(sanitized != input || sanitized.is_empty()); - } - } + if validation_result.is_ok() + && let Some(args) = params.get("arguments") + && let Some(input) = args.get("input").and_then(|v| v.as_str()) + { + let sanitized = sanitizer.sanitize_string(input); + assert!(sanitized != input || sanitized.is_empty()); } // If validation fails, that's also acceptable for strict config } diff --git a/pulseengine-auth/src/security/request_security.rs b/pulseengine-auth/src/security/request_security.rs index fc8643a5..84128e37 100644 --- a/pulseengine-auth/src/security/request_security.rs +++ b/pulseengine-auth/src/security/request_security.rs @@ -672,25 +672,25 @@ impl RequestSecurityValidator { } // Apply method-specific restrictions based on user role - if let Some(restricted_methods) = self.get_restricted_methods_for_user(auth_context) { - if restricted_methods.iter().any(|m| m == method) { - self.log_violation(SecurityViolation { - violation_type: SecurityViolationType::UnauthorizedMethod, - severity: SecuritySeverity::Critical, - description: format!( - "User {} attempted to access restricted method: {}", - auth_context.user_id.as_deref().unwrap_or("unknown"), - method - ), - field: Some("method".to_string()), - value: Some(method.to_string()), - timestamp: chrono::Utc::now(), - }); - - return Err(SecurityValidationError::UnsupportedMethod { - method: method.to_string(), - }); - } + if let Some(restricted_methods) = self.get_restricted_methods_for_user(auth_context) + && restricted_methods.iter().any(|m| m == method) + { + self.log_violation(SecurityViolation { + violation_type: SecurityViolationType::UnauthorizedMethod, + severity: SecuritySeverity::Critical, + description: format!( + "User {} attempted to access restricted method: {}", + auth_context.user_id.as_deref().unwrap_or("unknown"), + method + ), + field: Some("method".to_string()), + value: Some(method.to_string()), + timestamp: chrono::Utc::now(), + }); + + return Err(SecurityValidationError::UnsupportedMethod { + method: method.to_string(), + }); } // Apply enhanced injection detection for anonymous users diff --git a/pulseengine-auth/src/storage.rs b/pulseengine-auth/src/storage.rs index b3f64cbb..6ec730dc 100644 --- a/pulseengine-auth/src/storage.rs +++ b/pulseengine-auth/src/storage.rs @@ -178,35 +178,34 @@ impl FileStorage { { use std::os::unix::fs::MetadataExt; - if let Some(parent) = path.parent() { - if parent.exists() { - let metadata = fs::metadata(parent).await?; - - // Check if this is a network filesystem (basic check) - let _dev = metadata.dev(); - - // On many Unix systems, network filesystems have device IDs that indicate remote storage - // This is a basic check - in production you might want more sophisticated detection - if let Ok(mount_info) = fs::read_to_string("/proc/mounts").await { - let path_str = parent.to_string_lossy(); - for line in mount_info.lines() { - let parts: Vec<&str> = line.split_whitespace().collect(); - if parts.len() >= 3 { - let mount_point = parts[1]; - let fs_type = parts[2]; - - if path_str.starts_with(mount_point) { - // Check for network filesystem types - match fs_type { - "nfs" | "nfs4" | "cifs" | "smb" | "smbfs" - | "fuse.sshfs" => { - return Err(StorageError::Permission(format!( - "Storage path {} is on insecure network filesystem: {}", - path_str, fs_type - ))); - } - _ => {} + if let Some(parent) = path.parent() + && parent.exists() + { + let metadata = fs::metadata(parent).await?; + + // Check if this is a network filesystem (basic check) + let _dev = metadata.dev(); + + // On many Unix systems, network filesystems have device IDs that indicate remote storage + // This is a basic check - in production you might want more sophisticated detection + if let Ok(mount_info) = fs::read_to_string("/proc/mounts").await { + let path_str = parent.to_string_lossy(); + for line in mount_info.lines() { + let parts: Vec<&str> = line.split_whitespace().collect(); + if parts.len() >= 3 { + let mount_point = parts[1]; + let fs_type = parts[2]; + + if path_str.starts_with(mount_point) { + // Check for network filesystem types + match fs_type { + "nfs" | "nfs4" | "cifs" | "smb" | "smbfs" | "fuse.sshfs" => { + return Err(StorageError::Permission(format!( + "Storage path {} is on insecure network filesystem: {}", + path_str, fs_type + ))); } + _ => {} } } } @@ -382,17 +381,16 @@ impl FileStorage { while let Some(entry) = entries.next_entry().await? { let path = entry.path(); - if let Some(filename) = path.file_name().and_then(|n| n.to_str()) { - if filename.starts_with(&format!("{}.backup_", filename_stem)) { - if let Ok(metadata) = entry.metadata().await { - backups.push(( - path, - metadata - .modified() - .unwrap_or(std::time::SystemTime::UNIX_EPOCH), - )); - } - } + if let Some(filename) = path.file_name().and_then(|n| n.to_str()) + && filename.starts_with(&format!("{}.backup_", filename_stem)) + && let Ok(metadata) = entry.metadata().await + { + backups.push(( + path, + metadata + .modified() + .unwrap_or(std::time::SystemTime::UNIX_EPOCH), + )); } } @@ -453,27 +451,26 @@ impl FileStorage { match inotify.read_events_blocking(&mut buffer) { Ok(events) => { for event in events { - if let Some(name) = event.name { - if name.to_string_lossy().contains("keys") { - warn!( - "Detected unauthorized change to auth storage: {:?} (mask: {:?})", - name, event.mask - ); - - // Verify file permissions haven't been changed - if path.exists() { - #[cfg(unix)] - { - use std::os::unix::fs::PermissionsExt; - if let Ok(metadata) = std::fs::metadata(&path) { - let mode = - metadata.permissions().mode() & 0o777; - if mode != file_permissions { - error!( - "Security violation: File permissions changed from {:o} to {:o}", - file_permissions, mode - ); - } + if let Some(name) = event.name + && name.to_string_lossy().contains("keys") + { + warn!( + "Detected unauthorized change to auth storage: {:?} (mask: {:?})", + name, event.mask + ); + + // Verify file permissions haven't been changed + if path.exists() { + #[cfg(unix)] + { + use std::os::unix::fs::PermissionsExt; + if let Ok(metadata) = std::fs::metadata(&path) { + let mode = metadata.permissions().mode() & 0o777; + if mode != file_permissions { + error!( + "Security violation: File permissions changed from {:o} to {:o}", + file_permissions, mode + ); } } } diff --git a/pulseengine-auth/src/transport/stdio_auth.rs b/pulseengine-auth/src/transport/stdio_auth.rs index 05b3450a..2858aed5 100644 --- a/pulseengine-auth/src/transport/stdio_auth.rs +++ b/pulseengine-auth/src/transport/stdio_auth.rs @@ -59,16 +59,13 @@ impl StdioAuthExtractor { /// Extract authentication from environment variables fn extract_env_auth(&self) -> AuthExtractionResult { - if let Ok(api_key) = std::env::var(&self.config.api_key_env_var) { - if !api_key.is_empty() { - AuthUtils::validate_api_key_format(&api_key)?; - let context = TransportAuthContext::new( - api_key, - "Environment".to_string(), - TransportType::Stdio, - ); - return Ok(Some(context)); - } + if let Ok(api_key) = std::env::var(&self.config.api_key_env_var) + && !api_key.is_empty() + { + AuthUtils::validate_api_key_format(&api_key)?; + let context = + TransportAuthContext::new(api_key, "Environment".to_string(), TransportType::Stdio); + return Ok(Some(context)); } Ok(None) @@ -113,22 +110,20 @@ impl StdioAuthExtractor { } // Try in capabilities - if let Some(capabilities) = client_info.get("capabilities") { - if let Some(auth) = capabilities.get("authentication") { - if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) { - return Some(api_key.to_string()); - } - } + if let Some(capabilities) = client_info.get("capabilities") + && let Some(auth) = capabilities.get("authentication") + && let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) + { + return Some(api_key.to_string()); } } // Try in server capabilities/config - if let Some(capabilities) = params.get("capabilities") { - if let Some(auth) = capabilities.get("authentication") { - if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) { - return Some(api_key.to_string()); - } - } + if let Some(capabilities) = params.get("capabilities") + && let Some(auth) = capabilities.get("authentication") + && let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) + { + return Some(api_key.to_string()); } None @@ -192,10 +187,10 @@ impl StdioAuthExtractor { _request: &TransportRequest, ) -> TransportAuthContext { // Add process information - if let Ok(current_exe) = std::env::current_exe() { - if let Some(exe_name) = current_exe.file_name().and_then(|n| n.to_str()) { - context = context.with_metadata("process".to_string(), exe_name.to_string()); - } + if let Ok(current_exe) = std::env::current_exe() + && let Some(exe_name) = current_exe.file_name().and_then(|n| n.to_str()) + { + context = context.with_metadata("process".to_string(), exe_name.to_string()); } // Add working directory diff --git a/pulseengine-auth/src/transport/websocket_auth.rs b/pulseengine-auth/src/transport/websocket_auth.rs index 75163b13..e786d2a1 100644 --- a/pulseengine-auth/src/transport/websocket_auth.rs +++ b/pulseengine-auth/src/transport/websocket_auth.rs @@ -99,20 +99,19 @@ impl WebSocketAuthExtractor { if let Some(auth_header) = headers .get("Authorization") .or_else(|| headers.get("authorization")) + && auth_header.starts_with("Bearer ") { - if auth_header.starts_with("Bearer ") { - match AuthUtils::extract_bearer_token(auth_header) { - Ok(token) => { - AuthUtils::validate_api_key_format(&token)?; - let context = TransportAuthContext::new( - token, - "HandshakeHeaders".to_string(), - TransportType::WebSocket, - ); - return Ok(Some(context)); - } - Err(e) => return Err(e), + match AuthUtils::extract_bearer_token(auth_header) { + Ok(token) => { + AuthUtils::validate_api_key_format(&token)?; + let context = TransportAuthContext::new( + token, + "HandshakeHeaders".to_string(), + TransportType::WebSocket, + ); + return Ok(Some(context)); } + Err(e) => return Err(e), } } @@ -128,16 +127,16 @@ impl WebSocketAuthExtractor { } // Try WebSocket-specific headers - if let Some(api_key) = headers.get("Sec-WebSocket-Protocol") { - if let Some(auth_token) = self.extract_from_subprotocol(api_key) { - AuthUtils::validate_api_key_format(&auth_token)?; - let context = TransportAuthContext::new( - auth_token, - "Subprotocol".to_string(), - TransportType::WebSocket, - ); - return Ok(Some(context)); - } + if let Some(api_key) = headers.get("Sec-WebSocket-Protocol") + && let Some(auth_token) = self.extract_from_subprotocol(api_key) + { + AuthUtils::validate_api_key_format(&auth_token)?; + let context = TransportAuthContext::new( + auth_token, + "Subprotocol".to_string(), + TransportType::WebSocket, + ); + return Ok(Some(context)); } Ok(None) @@ -232,12 +231,11 @@ impl WebSocketAuthExtractor { } // Try nested in clientInfo - if let Some(client_info) = params.get("clientInfo") { - if let Some(auth) = client_info.get("authentication") { - if let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) { - return Some(api_key.to_string()); - } - } + if let Some(client_info) = params.get("clientInfo") + && let Some(auth) = client_info.get("authentication") + && let Some(api_key) = auth.get("api_key").and_then(|v| v.as_str()) + { + return Some(api_key.to_string()); } } @@ -298,12 +296,11 @@ impl WebSocketAuthExtractor { } // Check subprotocol for auth - if let Some(protocols) = request.get_header("Sec-WebSocket-Protocol") { - if let Some(auth_protocol) = &self.config.auth_subprotocol { - if protocols.contains(auth_protocol) { - return true; - } - } + if let Some(protocols) = request.get_header("Sec-WebSocket-Protocol") + && let Some(auth_protocol) = &self.config.auth_subprotocol + && protocols.contains(auth_protocol) + { + return true; } false diff --git a/pulseengine-auth/src/validation.rs b/pulseengine-auth/src/validation.rs index 54ba4634..d19432cc 100644 --- a/pulseengine-auth/src/validation.rs +++ b/pulseengine-auth/src/validation.rs @@ -46,10 +46,10 @@ pub fn extract_client_ip(headers: &HashMap) -> String { /// Helper function to extract API key from request headers or query parameters pub fn extract_api_key(headers: &HashMap, query: Option<&str>) -> Option { // Try Authorization header with Bearer token - if let Some(auth_header) = headers.get("authorization") { - if let Some(token) = auth_header.strip_prefix("Bearer ") { - return Some(token.to_string()); - } + if let Some(auth_header) = headers.get("authorization") + && let Some(token) = auth_header.strip_prefix("Bearer ") + { + return Some(token.to_string()); } // Try X-API-Key header @@ -60,10 +60,10 @@ pub fn extract_api_key(headers: &HashMap, query: Option<&str>) - // Try query parameter if let Some(query_string) = query { for param in query_string.split('&') { - if let Some((key, value)) = param.split_once('=') { - if key == "api_key" { - return Some(urlencoding::decode(value).unwrap_or_default().to_string()); - } + if let Some((key, value)) = param.split_once('=') + && key == "api_key" + { + return Some(urlencoding::decode(value).unwrap_or_default().to_string()); } } } diff --git a/pulseengine-auth/src/vault/infisical.rs b/pulseengine-auth/src/vault/infisical.rs index 8ef1def6..4050ee4b 100644 --- a/pulseengine-auth/src/vault/infisical.rs +++ b/pulseengine-auth/src/vault/infisical.rs @@ -327,10 +327,10 @@ impl VaultClient for InfisicalClient { let mut tags = HashMap::new(); // Include secret comment as a tag if present - if let Some(comment) = &secret.secret_comment { - if !comment.is_empty() { - tags.insert("comment".to_string(), comment.clone()); - } + if let Some(comment) = &secret.secret_comment + && !comment.is_empty() + { + tags.insert("comment".to_string(), comment.clone()); } // Include version as a tag if present @@ -639,10 +639,10 @@ impl InfisicalClient { let mut tags = HashMap::new(); // Include secret comment as a tag if present - if let Some(comment) = &secret.secret_comment { - if !comment.is_empty() { - tags.insert("comment".to_string(), comment.clone()); - } + if let Some(comment) = &secret.secret_comment + && !comment.is_empty() + { + tags.insert("comment".to_string(), comment.clone()); } // Include version as a tag @@ -672,10 +672,10 @@ impl InfisicalClient { // Try to get the current secret and extract version info match self.get_secret_with_metadata(name).await { Ok((_, metadata)) => { - if let Some(version_str) = metadata.version { - if let Ok(version) = version_str.parse::() { - return Ok(vec![version]); - } + if let Some(version_str) = metadata.version + && let Ok(version) = version_str.parse::() + { + return Ok(vec![version]); } Ok(vec![1]) // Default to version 1 if no version info } diff --git a/pulseengine-auth/src/vault/mod.rs b/pulseengine-auth/src/vault/mod.rs index 317389d0..f62e9a9e 100644 --- a/pulseengine-auth/src/vault/mod.rs +++ b/pulseengine-auth/src/vault/mod.rs @@ -171,10 +171,10 @@ impl VaultIntegration { // Check cache first { let cache = self.secret_cache.read().await; - if let Some((value, timestamp)) = cache.get(name) { - if timestamp.elapsed() < self.cache_ttl { - return Ok(value.clone()); - } + if let Some((value, timestamp)) = cache.get(name) + && timestamp.elapsed() < self.cache_ttl + { + return Ok(value.clone()); } } diff --git a/pulseengine-logging/src/alerting.rs b/pulseengine-logging/src/alerting.rs index f259b99b..7126d02e 100644 --- a/pulseengine-logging/src/alerting.rs +++ b/pulseengine-logging/src/alerting.rs @@ -595,10 +595,10 @@ impl AlertManager { while let Some(notification) = rx.recv().await { for channel_id in ¬ification.channels { - if let Some(channel) = config.channels.get(channel_id) { - if let Err(e) = Self::send_notification(channel, ¬ification).await { - error!("Failed to send notification to {}: {}", channel_id, e); - } + if let Some(channel) = config.channels.get(channel_id) + && let Err(e) = Self::send_notification(channel, ¬ification).await + { + error!("Failed to send notification to {}: {}", channel_id, e); } } } diff --git a/pulseengine-logging/src/correlation.rs b/pulseengine-logging/src/correlation.rs index 13620fe3..7507f4a6 100644 --- a/pulseengine-logging/src/correlation.rs +++ b/pulseengine-logging/src/correlation.rs @@ -460,10 +460,10 @@ impl CorrelationManager { completed_entry.end_time = Some(Utc::now()); let mut completed = completed_requests.write().await; - if completed.len() >= config.max_completed_requests { - if let Some(oldest_key) = completed.keys().next().cloned() { - completed.remove(&oldest_key); - } + if completed.len() >= config.max_completed_requests + && let Some(oldest_key) = completed.keys().next().cloned() + { + completed.remove(&oldest_key); } completed.insert(key, completed_entry); } diff --git a/pulseengine-logging/src/metrics.rs b/pulseengine-logging/src/metrics.rs index fa2a2776..bde79bd0 100644 --- a/pulseengine-logging/src/metrics.rs +++ b/pulseengine-logging/src/metrics.rs @@ -364,10 +364,10 @@ impl MetricsCollector { .push(duration_ms); // Keep only last 1000 response times per tool for memory efficiency - if let Some(times) = metrics.response_times_by_tool.get_mut(tool_name) { - if times.len() > 1000 { - times.drain(..times.len() - 1000); - } + if let Some(times) = metrics.response_times_by_tool.get_mut(tool_name) + && times.len() > 1000 + { + times.drain(..times.len() - 1000); } // Recalculate averages and percentiles diff --git a/pulseengine-logging/src/persistence.rs b/pulseengine-logging/src/persistence.rs index 61abb954..0974e156 100644 --- a/pulseengine-logging/src/persistence.rs +++ b/pulseengine-logging/src/persistence.rs @@ -331,10 +331,10 @@ async fn cleanup_old_files(data_dir: &Path, max_files: usize) -> Result<(), std: let entry = entry?; let path = entry.path(); - if path.extension().and_then(|s| s.to_str()) == Some("jsonl") { - if let Ok(metadata) = entry.metadata() { - files.push((path, metadata.modified()?)); - } + if path.extension().and_then(|s| s.to_str()) == Some("jsonl") + && let Ok(metadata) = entry.metadata() + { + files.push((path, metadata.modified()?)); } } diff --git a/pulseengine-logging/src/sanitization.rs b/pulseengine-logging/src/sanitization.rs index f689baa6..6c7fc798 100644 --- a/pulseengine-logging/src/sanitization.rs +++ b/pulseengine-logging/src/sanitization.rs @@ -161,17 +161,17 @@ impl LogSanitizer { } // Replace IP addresses if not preserved - if !self.config.preserve_ips { - if let Some(regex) = IP_REGEX.get() { - sanitized = regex.replace_all(&sanitized, "[IP_REDACTED]").to_string(); - } + if !self.config.preserve_ips + && let Some(regex) = IP_REGEX.get() + { + sanitized = regex.replace_all(&sanitized, "[IP_REDACTED]").to_string(); } // Replace UUIDs if not preserved - if !self.config.preserve_uuids { - if let Some(regex) = UUID_REGEX.get() { - sanitized = regex.replace_all(&sanitized, "[UUID_REDACTED]").to_string(); - } + if !self.config.preserve_uuids + && let Some(regex) = UUID_REGEX.get() + { + sanitized = regex.replace_all(&sanitized, "[UUID_REDACTED]").to_string(); } sanitized diff --git a/pulseengine-security/src/config.rs b/pulseengine-security/src/config.rs index 8b3db656..697e51a3 100644 --- a/pulseengine-security/src/config.rs +++ b/pulseengine-security/src/config.rs @@ -170,17 +170,17 @@ impl SecurityConfig { } // JWT expiry - if let Ok(jwt_expiry) = env::var("MCP_JWT_EXPIRY") { - if let Ok(expiry_seconds) = jwt_expiry.parse::() { - self.settings.jwt_expiry_seconds = expiry_seconds; - } + if let Ok(jwt_expiry) = env::var("MCP_JWT_EXPIRY") + && let Ok(expiry_seconds) = jwt_expiry.parse::() + { + self.settings.jwt_expiry_seconds = expiry_seconds; } // Rate limiting - if let Ok(rate_limit) = env::var("MCP_RATE_LIMIT") { - if let Some(config) = parse_rate_limit(&rate_limit) { - self.settings.rate_limit = config; - } + if let Ok(rate_limit) = env::var("MCP_RATE_LIMIT") + && let Some(config) = parse_rate_limit(&rate_limit) + { + self.settings.rate_limit = config; } // CORS origins diff --git a/pulseengine-security/src/middleware.rs b/pulseengine-security/src/middleware.rs index cc3f043d..20406d57 100644 --- a/pulseengine-security/src/middleware.rs +++ b/pulseengine-security/src/middleware.rs @@ -52,34 +52,33 @@ impl SecurityMiddleware { } // Try API key authentication first - if let Some(ref validator) = self.api_key_validator { - if let Some(api_key) = extract_api_key(headers) { - match validator.validate_api_key(&api_key) { - Ok(user_id) => { - let auth_context = AuthContext::new(user_id) - .with_api_key(api_key) - .with_role("api_user"); - return Ok(Some(auth_context)); - } - Err(e) => { - debug!("API key validation failed: {}", e); - } + if let Some(ref validator) = self.api_key_validator + && let Some(api_key) = extract_api_key(headers) + { + match validator.validate_api_key(&api_key) { + Ok(user_id) => { + let auth_context = AuthContext::new(user_id) + .with_api_key(api_key) + .with_role("api_user"); + return Ok(Some(auth_context)); + } + Err(e) => { + debug!("API key validation failed: {}", e); } } } // Try JWT token authentication - if let Some(ref validator) = self.token_validator { - if let Some(token) = extract_bearer_token(headers) { - match validator.validate_token(&token) { - Ok(claims) => { - let auth_context = - AuthContext::new(claims.sub.clone()).with_jwt_claims(claims); - return Ok(Some(auth_context)); - } - Err(e) => { - debug!("JWT validation failed: {}", e); - } + if let Some(ref validator) = self.token_validator + && let Some(token) = extract_bearer_token(headers) + { + match validator.validate_token(&token) { + Ok(claims) => { + let auth_context = AuthContext::new(claims.sub.clone()).with_jwt_claims(claims); + return Ok(Some(auth_context)); + } + Err(e) => { + debug!("JWT validation failed: {}", e); } } } @@ -202,24 +201,24 @@ pub struct RequestId(pub String); /// Extract API key from request headers fn extract_api_key(headers: &HeaderMap) -> Option { // Try Authorization header first - if let Some(auth_header) = headers.get("authorization") { - if let Ok(auth_str) = auth_header.to_str() { - if let Some(key) = auth_str.strip_prefix("ApiKey ") { - return Some(key.to_string()); - } - if let Some(key) = auth_str.strip_prefix("Bearer ") { - if key.starts_with("mcp_") { - return Some(key.to_string()); - } - } + if let Some(auth_header) = headers.get("authorization") + && let Ok(auth_str) = auth_header.to_str() + { + if let Some(key) = auth_str.strip_prefix("ApiKey ") { + return Some(key.to_string()); + } + if let Some(key) = auth_str.strip_prefix("Bearer ") + && key.starts_with("mcp_") + { + return Some(key.to_string()); } } // Try X-API-Key header - if let Some(key_header) = headers.get("x-api-key") { - if let Ok(key_str) = key_header.to_str() { - return Some(key_str.to_string()); - } + if let Some(key_header) = headers.get("x-api-key") + && let Ok(key_str) = key_header.to_str() + { + return Some(key_str.to_string()); } None @@ -227,14 +226,13 @@ fn extract_api_key(headers: &HeaderMap) -> Option { /// Extract Bearer token from request headers fn extract_bearer_token(headers: &HeaderMap) -> Option { - if let Some(auth_header) = headers.get("authorization") { - if let Ok(auth_str) = auth_header.to_str() { - if let Some(token) = auth_str.strip_prefix("Bearer ") { - // Make sure it's not an API key - if !token.starts_with("mcp_") { - return Some(token.to_string()); - } - } + if let Some(auth_header) = headers.get("authorization") + && let Ok(auth_str) = auth_header.to_str() + && let Some(token) = auth_str.strip_prefix("Bearer ") + { + // Make sure it's not an API key + if !token.starts_with("mcp_") { + return Some(token.to_string()); } } @@ -246,18 +244,17 @@ fn extract_client_id(request: &Request) -> String { // Try to get client IP from headers (proxy headers) let headers = request.headers(); - if let Some(forwarded_for) = headers.get("x-forwarded-for") { - if let Ok(ip_str) = forwarded_for.to_str() { - if let Some(first_ip) = ip_str.split(',').next() { - return first_ip.trim().to_string(); - } - } + if let Some(forwarded_for) = headers.get("x-forwarded-for") + && let Ok(ip_str) = forwarded_for.to_str() + && let Some(first_ip) = ip_str.split(',').next() + { + return first_ip.trim().to_string(); } - if let Some(real_ip) = headers.get("x-real-ip") { - if let Ok(ip_str) = real_ip.to_str() { - return ip_str.to_string(); - } + if let Some(real_ip) = headers.get("x-real-ip") + && let Ok(ip_str) = real_ip.to_str() + { + return ip_str.to_string(); } // Fallback to connection info (if available) @@ -275,25 +272,24 @@ fn is_https_request(request: &Request) -> bool { // Check forwarded protocol headers (common in proxy setups) let headers = request.headers(); - if let Some(forwarded_proto) = headers.get("x-forwarded-proto") { - if let Ok(proto_str) = forwarded_proto.to_str() { - return proto_str.to_lowercase() == "https"; - } + if let Some(forwarded_proto) = headers.get("x-forwarded-proto") + && let Ok(proto_str) = forwarded_proto.to_str() + { + return proto_str.to_lowercase() == "https"; } - if let Some(forwarded_ssl) = headers.get("x-forwarded-ssl") { - if let Ok(ssl_str) = forwarded_ssl.to_str() { - return ssl_str.to_lowercase() == "on"; - } + if let Some(forwarded_ssl) = headers.get("x-forwarded-ssl") + && let Ok(ssl_str) = forwarded_ssl.to_str() + { + return ssl_str.to_lowercase() == "on"; } // For development, assume localhost connections are acceptable - if let Some(host) = headers.get("host") { - if let Ok(host_str) = host.to_str() { - if host_str.starts_with("localhost") || host_str.starts_with("127.0.0.1") { - return true; - } - } + if let Some(host) = headers.get("host") + && let Ok(host_str) = host.to_str() + && (host_str.starts_with("localhost") || host_str.starts_with("127.0.0.1")) + { + return true; } false diff --git a/rust-toolchain.toml b/rust-toolchain.toml index 232a0661..93ee5263 100644 --- a/rust-toolchain.toml +++ b/rust-toolchain.toml @@ -1,7 +1,7 @@ [toolchain] # Pin Rust version to ensure consistency across all environments # This file is used by rustup to automatically install and use the correct toolchain -channel = "1.88" +channel = "1.89" components = ["rustfmt", "clippy", "llvm-tools-preview"] targets = ["x86_64-unknown-linux-gnu", "x86_64-apple-darwin", "x86_64-pc-windows-msvc"]