Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
154 changes: 154 additions & 0 deletions rsky-feedgen/src/apis/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -1119,6 +1119,29 @@ pub fn add_visitor(
Ok(())
}

pub fn ban_from_tv(
subject: &String,
reason_param: Option<String>,
tags_param: Option<Vec<String>>,
) -> Result<(), Box<dyn std::error::Error>> {
use crate::schema::banned_from_tv::dsl::*;

let connection = &mut establish_connection()?;

diesel::insert_into(banned_from_tv)
.values((
did.eq(subject.clone()),
reason.eq(reason_param),
tags.eq(tags_param),
))
.on_conflict(did)
.do_nothing()
.execute(connection)
.expect("Error inserting into banned_from_tv");

Ok(())
}

pub fn is_banned_from_tv(subject: &String) -> Result<bool, Box<dyn std::error::Error>> {
use crate::schema::banned_from_tv::dsl::*;

Expand All @@ -1134,6 +1157,57 @@ pub fn is_banned_from_tv(subject: &String) -> Result<bool, Box<dyn std::error::E
return if count > 0 { Ok(true) } else { Ok(false) };
}

pub fn unban_from_tv(subject: &String) -> Result<(), Box<dyn std::error::Error>> {
use crate::schema::banned_from_tv::dsl::*;

let connection = &mut establish_connection()?;

diesel::delete(banned_from_tv.filter(did.eq(subject.clone())))
.execute(connection)
.expect("Error deleting from banned_from_tv");

Ok(())
}

pub fn search_banned_from_tv(
search_did: Option<String>,
search_tag: Option<String>,
limit: Option<i64>,
offset: Option<i64>,
) -> Result<Vec<crate::models::BannedFromTv>, Box<dyn std::error::Error>> {
use crate::models::BannedFromTv;
use crate::schema::banned_from_tv::dsl::*;

let connection = &mut establish_connection()?;

let mut query = banned_from_tv
.select(BannedFromTv::as_select())
.order(createdAt.desc())
.into_boxed();

if let Some(search_did_val) = search_did {
query = query.filter(did.like(format!("%{}%", search_did_val)));
}

if let Some(search_tag_val) = search_tag {
query = query.filter(tags.contains(vec![Some(search_tag_val)]));
}

if let Some(limit_val) = limit {
query = query.limit(limit_val);
}

if let Some(offset_val) = offset {
query = query.offset(offset_val);
}

let results = query
.load(connection)
.expect("Error loading banned_from_tv records");

Ok(results)
}

pub async fn get_cursor(
service_: String,
connection: ReadReplicaConn,
Expand Down Expand Up @@ -1531,4 +1605,84 @@ mod tests {
)
.await;
}

#[test]
fn test_ban_from_tv() {
let test_did = "did:plc:test123".to_string();
let test_reason = Some("Test ban reason".to_string());
let test_tags = Some(vec!["spam".to_string(), "abuse".to_string()]);

// Insert the banned user
let result = ban_from_tv(&test_did, test_reason.clone(), test_tags.clone());
assert!(result.is_ok());

// Verify the user is banned
let is_banned = is_banned_from_tv(&test_did);
assert!(is_banned.is_ok());
assert_eq!(is_banned.unwrap(), true);

// Test duplicate ban (should not error due to on_conflict)
let duplicate_result = ban_from_tv(&test_did, None, None);
assert!(duplicate_result.is_ok());

// Unban the user
let unban_result = unban_from_tv(&test_did);
assert!(unban_result.is_ok());

// Verify they are no longer banned
let is_still_banned = is_banned_from_tv(&test_did);
assert_eq!(is_still_banned.unwrap(), false);
}

#[test]
fn test_search_banned_from_tv() {
let test_did1 = "did:plc:search001".to_string();
let test_did2 = "did:plc:search002".to_string();
let test_did3 = "did:plc:search003".to_string();

// Insert multiple banned users
let _ = ban_from_tv(
&test_did1,
Some("Reason 1".to_string()),
Some(vec!["tag1".to_string()]),
);
let _ = ban_from_tv(
&test_did2,
Some("Reason 2".to_string()),
Some(vec!["tag2".to_string()]),
);
let _ = ban_from_tv(
&test_did3,
Some("Reason 3".to_string()),
Some(vec!["tag1".to_string()]),
);

// Search without filters
let all_results = search_banned_from_tv(None, None, Some(100), None);
assert!(all_results.is_ok());
let all_banned = all_results.unwrap();
assert!(all_banned.len() >= 3);

// Search by DID
let did_results = search_banned_from_tv(Some("search001".to_string()), None, None, None);
assert!(did_results.is_ok());
let did_banned = did_results.unwrap();
assert!(did_banned.iter().any(|b| b.did.contains("search001")));

// Search by tag
let tag_results = search_banned_from_tv(None, Some("tag1".to_string()), None, None);
assert!(tag_results.is_ok());
let tag_banned = tag_results.unwrap();
assert!(tag_banned.len() >= 2);

// Test pagination
let limited_results = search_banned_from_tv(None, None, Some(1), Some(0));
assert!(limited_results.is_ok());
assert_eq!(limited_results.unwrap().len(), 1);

// Clean up
let _ = unban_from_tv(&test_did1);
let _ = unban_from_tv(&test_did2);
let _ = unban_from_tv(&test_did3);
}
}
3 changes: 3 additions & 0 deletions rsky-feedgen/src/main.rs
Original file line number Diff line number Diff line change
Expand Up @@ -153,6 +153,9 @@ fn rocket() -> _ {
well_known,
get_cursor,
update_cursor,
ban_user,
unban_user,
list_banned_users,
all_options
],
)
Expand Down
16 changes: 16 additions & 0 deletions rsky-feedgen/src/models/banned_from_tv.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,16 @@
use diesel::prelude::*;

#[derive(Queryable, Selectable, Clone, Debug, PartialEq, Serialize, Deserialize)]
#[diesel(table_name = crate::schema::banned_from_tv)]
#[diesel(check_for_backend(diesel::pg::Pg))]
pub struct BannedFromTv {
#[serde(rename = "did")]
pub did: String,
#[serde(rename = "reason", skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
#[serde(rename = "createdAt", skip_serializing_if = "Option::is_none")]
#[diesel(column_name = createdAt)]
pub created_at: Option<String>,
#[serde(rename = "tags", skip_serializing_if = "Option::is_none")]
pub tags: Option<Vec<Option<String>>>,
}
11 changes: 11 additions & 0 deletions rsky-feedgen/src/models/banned_from_tv_request.rs
Original file line number Diff line number Diff line change
@@ -0,0 +1,11 @@
use serde::{Deserialize, Serialize};

#[derive(Clone, Debug, PartialEq, Serialize, Deserialize)]
pub struct BannedFromTvRequest {
#[serde(rename = "did")]
pub did: String,
#[serde(rename = "reason", skip_serializing_if = "Option::is_none")]
pub reason: Option<String>,
#[serde(rename = "tags", skip_serializing_if = "Option::is_none")]
pub tags: Option<Vec<String>>,
}
4 changes: 4 additions & 0 deletions rsky-feedgen/src/models/mod.rs
Original file line number Diff line number Diff line change
Expand Up @@ -31,3 +31,7 @@ pub mod known_service;
pub use self::known_service::KnownService;
pub mod jwt_parts;
pub use self::jwt_parts::JwtParts;
pub mod banned_from_tv;
pub use self::banned_from_tv::BannedFromTv;
pub mod banned_from_tv_request;
pub use self::banned_from_tv_request::BannedFromTvRequest;
69 changes: 69 additions & 0 deletions rsky-feedgen/src/routes.rs
Original file line number Diff line number Diff line change
Expand Up @@ -539,3 +539,72 @@ pub async fn well_known() -> Result<
}
}
}

#[rocket::post("/admin/ban", format = "json", data = "<body>")]
pub fn ban_user(
body: Json<crate::models::BannedFromTvRequest>,
_key: ApiKey<'_>,
) -> Result<(), status::Custom<Json<crate::models::InternalErrorMessageResponse>>> {
match crate::apis::ban_from_tv(&body.did, body.reason.clone(), body.tags.clone()) {
Ok(_) => Ok(()),
Err(error) => {
eprintln!("Internal Error: {error}");
let internal_error = crate::models::InternalErrorMessageResponse {
code: Some(crate::models::InternalErrorCode::InternalError),
message: Some(error.to_string()),
};
Err(status::Custom(
Status::InternalServerError,
Json(internal_error),
))
}
}
}

#[rocket::delete("/admin/ban?<did>")]
pub fn unban_user(
did: &str,
_key: ApiKey<'_>,
) -> Result<(), status::Custom<Json<crate::models::InternalErrorMessageResponse>>> {
match crate::apis::unban_from_tv(&did.to_string()) {
Ok(_) => Ok(()),
Err(error) => {
eprintln!("Internal Error: {error}");
let internal_error = crate::models::InternalErrorMessageResponse {
code: Some(crate::models::InternalErrorCode::InternalError),
message: Some(error.to_string()),
};
Err(status::Custom(
Status::InternalServerError,
Json(internal_error),
))
}
}
}

#[rocket::get("/admin/banned?<did>&<tag>&<limit>&<offset>", format = "json")]
pub fn list_banned_users(
did: Option<String>,
tag: Option<String>,
limit: Option<i64>,
offset: Option<i64>,
_key: ApiKey<'_>,
) -> Result<
Json<Vec<crate::models::BannedFromTv>>,
status::Custom<Json<crate::models::InternalErrorMessageResponse>>,
> {
match crate::apis::search_banned_from_tv(did, tag, limit, offset) {
Ok(results) => Ok(Json(results)),
Err(error) => {
eprintln!("Internal Error: {error}");
let internal_error = crate::models::InternalErrorMessageResponse {
code: Some(crate::models::InternalErrorCode::InternalError),
message: Some(error.to_string()),
};
Err(status::Custom(
Status::InternalServerError,
Json(internal_error),
))
}
}
}