Skip to content
Merged
Show file tree
Hide file tree
Changes from all commits
Commits
File filter

Filter by extension

Filter by extension

Conversations
Failed to load comments.
Loading
Jump to
Jump to file
Failed to load files.
Loading
Diff view
Diff view
23 changes: 13 additions & 10 deletions mw_api/DB_bots/db_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -10,8 +10,10 @@
# db.select(table_name, args)

"""

# ---
from typing import Dict, Any, List, Optional, Iterator
from typing import Any, Dict, Iterator, List, Optional

import sqlite_utils


Expand All @@ -25,9 +27,13 @@ def __init__(self, db_path: str) -> None:
# self.db = sqlite_utils.Database(db_path, tracer=tracer)
self.db = sqlite_utils.Database(db_path)

def create_table(self, table_name: str, fields: Dict[str, Any], pk: str = "id", **kwargs) -> None:
def create_table(
self, table_name: str, fields: Dict[str, Any], pk: str = "id", **kwargs
) -> None:
# Create table if it doesn't exist
self.db[table_name].create(fields, pk=pk, if_not_exists=True, ignore=True, **kwargs)
self.db[table_name].create(
fields, pk=pk, if_not_exists=True, ignore=True, **kwargs
)

def query(self, sql: str) -> List[tuple]:
# return self.db.query(sql)
Expand All @@ -52,7 +58,9 @@ def insert(self, table_name: str, data: Dict[str, Any], check: bool = True) -> N
self.db[table_name].insert(data, ignore=True, pk="id")
del data

def insert_all(self, table_name: str, datalist: List[Dict[str, Any]], prnt: bool = True) -> None:
def insert_all(
self, table_name: str, datalist: List[Dict[str, Any]], prnt: bool = True
) -> None:
if prnt:
print(f"inserting {len(datalist)} rows")
self.db[table_name].insert_all(datalist, ignore=True, pk="id")
Expand Down Expand Up @@ -269,12 +277,7 @@ def table_exists(self, table: str) -> bool:
"""
return table in self._db.table_names()

def create_table(
self,
table: str,
schema: Dict[str, type],
pk: str = "id"
) -> None:
def create_table(self, table: str, schema: Dict[str, type], pk: str = "id") -> None:
"""
Create a table if it doesn't exist.

Expand Down
17 changes: 15 additions & 2 deletions mw_api/DB_bots/pymysql_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -2,14 +2,27 @@
from mw_api import pymysql_bot
# result = pymysql_bot.sql_connect_pymysql(query, return_dict=False, values=None, main_args={}, credentials={}, conversions=None)
"""

import copy

import pymysql
import pymysql.cursors


def sql_connect_pymysql(query, return_dict=False, values=None, main_args={}, credentials={}, conversions=None, many=False, **kwargs):
def sql_connect_pymysql(
query,
return_dict=False,
values=None,
main_args={},
credentials={},
conversions=None,
many=False,
**kwargs
):
args = copy.deepcopy(main_args)
args["cursorclass"] = pymysql.cursors.DictCursor if return_dict else pymysql.cursors.Cursor
args["cursorclass"] = (
pymysql.cursors.DictCursor if return_dict else pymysql.cursors.Cursor
)
if conversions:
args["conv"] = conversions

Expand Down
13 changes: 8 additions & 5 deletions mw_api/__init__.py
Original file line number Diff line number Diff line change
@@ -1,11 +1,14 @@
# -*- coding: utf-8 -*-
"""
"""
""" """
from pathlib import Path

from .all_apis import ALL_APIS
from .api_utils import botEdit, printe, txtlib, wd_sparql
from .DB_bots import db_bot, pymysql_bot
from .api_utils import botEdit
from .api_utils import printe, txtlib, wd_sparql
from .logging_config import setup_logging
from .super.login_wrap import LoginWrap
from .all_apis import ALL_APIS

setup_logging(Path(__name__).parent.name)

__all__ = [
"ALL_APIS",
Expand Down
8 changes: 3 additions & 5 deletions mw_api/all_apis.py
Original file line number Diff line number Diff line change
@@ -1,8 +1,6 @@
"""
"""
from .pages_bots.all_apis import (
ALL_APIS,
)
""" """

from .pages_bots.all_apis import ALL_APIS

__all__ = [
"ALL_APIS",
Expand Down
1 change: 1 addition & 0 deletions mw_api/api/__init__.py
Original file line number Diff line number Diff line change
Expand Up @@ -4,6 +4,7 @@

Contains API client and related query utilities.
"""

from .client import MediaWikiApiClient
from .token_manager import TokenManager

Expand Down
47 changes: 18 additions & 29 deletions mw_api/api/client.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,12 +5,16 @@
Provides the MediaWikiApiClient class for low-level API operations,
abstracting HTTP requests and response parsing.
"""
from typing import Dict, Any, Optional, List

import logging
from typing import Any, Dict, List, Optional

import requests

from ..core.request_config import RequestConfig
from ..core.exceptions import ApiError, parse_api_error
from ..api_utils import printe
from ..core.request_config import RequestConfig

logger = logging.getLogger(__name__)


class MediaWikiApiClient:
Expand All @@ -30,7 +34,7 @@ def __init__(
self,
endpoint: str,
session: Optional[requests.Session] = None,
user_agent: str = "mw_api Python client"
user_agent: str = "mw_api Python client",
) -> None:
"""
Initialize the API client.
Expand All @@ -47,9 +51,7 @@ def __init__(
self._headers = {"User-Agent": self.user_agent}

def post(
self,
params: Dict[str, Any],
config: Optional[RequestConfig] = None
self, params: Dict[str, Any], config: Optional[RequestConfig] = None
) -> Dict[str, Any]:
"""
Make a POST request to the API.
Expand Down Expand Up @@ -80,19 +82,17 @@ def post(
data=params,
headers=self._headers,
timeout=config.timeout,
files=config.files
files=config.files,
)
response.raise_for_status()
return response.json()

except requests.RequestException as e:
printe.warn(str(e))
logger.warning(str(e))
return {}

def get(
self,
params: Dict[str, Any],
config: Optional[RequestConfig] = None
self, params: Dict[str, Any], config: Optional[RequestConfig] = None
) -> Dict[str, Any]:
"""
Make a GET request to the API.
Expand All @@ -113,19 +113,17 @@ def get(
self.endpoint,
params=params,
headers=self._headers,
timeout=config.timeout
timeout=config.timeout,
)
response.raise_for_status()
return response.json()

except requests.RequestException as e:
printe.warn(str(e))
logger.warning(str(e))
return {}

def request(
self,
params: Dict[str, Any],
config: Optional[RequestConfig] = None
self, params: Dict[str, Any], config: Optional[RequestConfig] = None
) -> Dict[str, Any]:
"""
Make a request using the configured method.
Expand All @@ -144,9 +142,7 @@ def request(
return self.post(params, config)

def query(
self,
params: Dict[str, Any],
continue_key: str = "continue"
self, params: Dict[str, Any], continue_key: str = "continue"
) -> List[Dict[str, Any]]:
"""
Make a query with automatic continuation.
Expand Down Expand Up @@ -194,11 +190,7 @@ def fetch_csrf_token(self) -> str:
Returns:
The CSRF token string.
"""
params = {
"action": "query",
"meta": "tokens",
"type": "csrf"
}
params = {"action": "query", "meta": "tokens", "type": "csrf"}
result = self.post(params)
token = result.get("query", {}).get("tokens", {}).get("csrftoken", "")
self._csrf_token = token
Expand All @@ -211,10 +203,7 @@ def is_authenticated(self) -> bool:
Returns:
True if authenticated, False otherwise.
"""
params = {
"action": "query",
"meta": "userinfo"
}
params = {"action": "query", "meta": "userinfo"}
result = self.post(params)
userinfo = result.get("query", {}).get("userinfo", {})
return "anon" not in userinfo and "id" in userinfo
15 changes: 8 additions & 7 deletions mw_api/api/token_manager.py
Original file line number Diff line number Diff line change
Expand Up @@ -5,10 +5,13 @@
Provides the TokenManager class for managing authentication tokens,
extracted from super_login.py for better separation of concerns.
"""
from typing import Optional, Dict, Any

import logging
from typing import Any, Dict, Optional

import requests

from ..api_utils import printe
logger = logging.getLogger(__name__)


class TokenManager:
Expand All @@ -26,7 +29,7 @@ def __init__(
self,
endpoint: str,
session: requests.Session,
headers: Optional[Dict[str, str]] = None
headers: Optional[Dict[str, str]] = None,
) -> None:
"""
Initialize the TokenManager.
Expand Down Expand Up @@ -84,16 +87,14 @@ def _fetch_token(self, token_type: str) -> str:

try:
response = self._session.post(
self._endpoint,
data=params,
headers=self._headers
self._endpoint, data=params, headers=self._headers
)
data = response.json()
token_key = f"{token_type}token"
return data.get("query", {}).get("tokens", {}).get(token_key, "")

except Exception as e:
printe.warn(str(e))
logger.warning(str(e))
return ""

def invalidate_csrf(self) -> None:
Expand Down
38 changes: 27 additions & 11 deletions mw_api/api_utils/ask_bot.py
Original file line number Diff line number Diff line change
Expand Up @@ -3,18 +3,32 @@
from ...api_utils.ask_bot import ASK_BOT

"""
from . import printe

import logging

from ..core.config import get_default_config
from . import printe

logger = logging.getLogger(__name__)
yes_answer = ["y", "a", "", "Y", "A", "all", "aaa"]

Save_or_Ask = {}


class ASK_BOT:
def __init__(self):
pass

def ask_put(self, nodiff=False, newtext="", text="", message="", job="Genral", username="", summary=""):
def ask_put(
self,
nodiff=False,
newtext="",
text="",
message="",
job="Genral",
username="",
summary="",
):
"""
Prompts the user to confirm saving changes to a page, optionally displaying a diff.

Expand All @@ -36,27 +50,29 @@ def ask_put(self, nodiff=False, newtext="", text="", message="", job="Genral", u
if len(newtext) < 70000 and len(text) < 70000 or config.show_diff:
printe.showDiff(text, newtext)
else:
printe.output("showDiff error..")
logger.info("showDiff error..")
# ---
printe.output(f"diference in bytes: {len(newtext) - len(text):,}")
printe.output(f"len of text: {len(text):,}, len of newtext: {len(newtext):,}")
logger.info(f"diference in bytes: {len(newtext) - len(text):,}")
logger.info(
f"len of text: {len(text):,}, len of newtext: {len(newtext):,}"
)
# ---
if summary:
printe.output(f"-Edit summary: {summary}")
logger.info(f"-Edit summary: {summary}")
# ---
printe.output(f"<<lightyellow>>ASK_BOT: {message}? (yes, no) {username=}")
logger.info(f"<<lightyellow>>ASK_BOT: {message}? (yes, no) {username=}")
# ---
sa = input("([y]es, [N]o, [a]ll)?")
# ---
if sa == "a":
Save_or_Ask[job] = True
# ---
printe.output("<<lightgreen>> ---------------------------------")
printe.output(f"<<lightgreen>> save all:{job} without asking.")
printe.output("<<lightgreen>> ---------------------------------")
logger.info("<<lightgreen>> ---------------------------------")
logger.info(f"<<lightgreen>> save all:{job} without asking.")
logger.info("<<lightgreen>> ---------------------------------")
# ---
if sa not in yes_answer:
printe.output("wrong answer")
logger.info("wrong answer")
return False
# ---
return True
Loading