A lightweight framework for training Deep Q-Network (DQN) agents in discrete action space environments using Gymnasium.
Currently supported configurations:
- `CartPole-v1`
- `MountainCar-v0`
- `Acrobot-v1`
- `ALE/Pong-v5` (CNN DQN)
- `train.py`: DQN training + best model saving + progress plot.
- `evaluate.py`: trained model evaluation (greedy policy, epsilon=0) with statistics summary.
- `play.py`: running the trained model in `render_mode="human"` mode.
- `agents/dqn_agent.py`: agent logic (epsilon-greedy, training step, soft update of the target network).
- `models/dqn_network.py`: MLP built dynamically from a hidden-layers list + factory `create_network(config, state_shape, action_dim)`.
- `models/cnn_dqn_network.py`: CNN DQN with configurable Conv2d layers and Dueling support.
- `memory/replay_buffer.py`: three replay buffer variants, `ReplayBuffer` (uniform), `PrioritizedReplayBuffer` (PER), and `NstepReplayBuffer` (N-step returns), with factory `create_buffer(config)`.
- `utils/evaluate.py`: shared `evaluate_policy()` function used by `evaluate.py` and `train.py`.
- `utils/wrappers.py`: `make_env()` with `frame_skip`, `wrap_env()` with image preprocessing (Atari + generic).
- `config/config.py`: centralized hyperparameter configuration and per-environment presets.
- Python 3.10+
- Python packages:
  - `torch`
  - `gymnasium`
  - `gymnasium[atari]`, `ale-py` (optional, required for Atari environments)
  - `opencv-python` (required by `AtariPreprocessing`)
  - `numpy`
  - `matplotlib`
  - `tensorboard`
Example installation:
```bash
pip install torch gymnasium numpy matplotlib tensorboard

# Optional: for Atari environments (e.g. ALE/Pong-v5):
pip install "gymnasium[atari]" ale-py opencv-python
```

The base `version.py` + `README.md` contract is now enforced by the shared validator from AI_Instruction, using the repo-local descriptors in `.github/versioning/`.
CI runs this shared check as a blocking gate before the more expensive jobs, and the optional local `.pre-commit-config.yaml` calls the same shared hook for faster feedback. CI remains the source of truth.
The repository still includes a stricter local validator for release metadata consistency:
```bash
python scripts/validate_version_consistency.py --repo-root .
```

This optional script additionally checks that the matching `CHANGELOG.md` section stays in sync, so it can still be used as a manual release-prep helper when a stricter local check is useful.
- Training:

  ```bash
  python train.py
  ```

- Training for a specific environment:

  ```bash
  python train.py CartPole-v1
  python train.py MountainCar-v0
  python train.py Acrobot-v1
  python train.py ALE/Pong-v5
  ```

- Training with a specific seed (overrides the config value):

  ```bash
  python train.py MountainCar-v0 --seed 123
  ```

- Watching the trained agent:

  ```bash
  python play.py
  python play.py CartPole-v1
  python play.py MountainCar-v0 --play-episodes 10
  ```

- Evaluating the trained model (greedy policy, epsilon=0):

  ```bash
  python evaluate.py CartPole-v1
  python evaluate.py MountainCar-v0 --episodes 50
  python evaluate.py Acrobot-v1 --episodes 100 --render
  python evaluate.py CartPole-v1 --render --render-episodes 5
  ```

- Running unit tests:

  ```bash
  pytest tests/ -v
  ```

- Running tests with coverage report:
  ```bash
  pytest tests/ --cov=config --cov=agents --cov=memory --cov=utils --cov=models --cov-report=term-missing
  ```

- The agent selects actions using an epsilon-greedy policy.
- Transitions are stored in a replay buffer (uniform, PER, or N-step, configurable via `buffer_type`).
- Network updates use the Double DQN variant (see the sketch after this list):
  - action selection: `argmax` through `policy_net`,
  - evaluation of that action through `target_net`.
- `target_net` is updated via soft update with parameter `tau` (or hard update every `target_update_freq` steps for CNN).
- For `CartPole-v1`, terminal transitions (`terminated`) receive a training penalty of `-10.0`.
- Loss function: Smooth L1 (Huber loss).
- For image-based environments (e.g. `ALE/Pong-v5`), a CNN DQN network with frame preprocessing (grayscale, resize, frame stacking) is used.
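For reference, the Double DQN target and the soft target-network update can be written in a few lines of PyTorch. This is an illustrative sketch of the general technique, not a copy of `agents/dqn_agent.py`; the batch layout and variable names are assumptions.

```python
import torch
import torch.nn.functional as F

def double_dqn_loss(policy_net, target_net, batch, gamma):
    # Illustrative batch layout: tensors of shape [B, ...];
    # `actions` is int64, `dones` is float (1.0 for terminal transitions).
    states, actions, rewards, next_states, dones = batch

    # Q(s, a) from the online network for the actions actually taken
    q_values = policy_net(states).gather(1, actions.unsqueeze(1)).squeeze(1)

    with torch.no_grad():
        # Double DQN: select the next action with policy_net ...
        next_actions = policy_net(next_states).argmax(dim=1, keepdim=True)
        # ... but evaluate that action with target_net
        next_q = target_net(next_states).gather(1, next_actions).squeeze(1)
        targets = rewards + gamma * next_q * (1.0 - dones)

    # Smooth L1 (Huber) loss between predicted and target Q-values
    return F.smooth_l1_loss(q_values, targets)

def soft_update(target_net, policy_net, tau):
    # target <- tau * online + (1 - tau) * target
    for t_param, p_param in zip(target_net.parameters(), policy_net.parameters()):
        t_param.data.copy_(tau * p_param.data + (1.0 - tau) * t_param.data)
```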
Configuration is located in `config/config.py` in the `Config` class.

Key fields:

- `gamma`, `lr`, `batch_size`, `memory_size`
- `epsilon`, `epsilon_decay`, `epsilon_min` (epsilon-greedy schedule, sketched below)
- `tau`
- `hidden_layers`
- `num_episodes`, `min_replay_size`, `train_every_steps`
- `solved_threshold`
- `model_path`, `plot_path`, `play_episodes`
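As a purely illustrative sketch of how `epsilon`, `epsilon_decay`, and `epsilon_min` typically interact in an epsilon-greedy agent (the actual schedule lives in `agents/dqn_agent.py` and may differ):

```python
import random

def select_action(q_values, epsilon, action_dim):
    """Epsilon-greedy: explore with probability epsilon, otherwise act greedily.

    `q_values` is a 1-D array/tensor of Q-values for the current state.
    """
    if random.random() < epsilon:
        return random.randrange(action_dim)
    return int(q_values.argmax())

# A common multiplicative decay, clipped at epsilon_min (applied per episode or per step):
# epsilon = max(epsilon_min, epsilon * epsilon_decay)
```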
Replay Buffer:
- `buffer_type`: buffer type, one of `"replay"` (uniform), `"prioritized"` (PER), `"nstep"` (N-step returns). Default: `"prioritized"`. Automatically sets `use_per`.
- `nstep_n`: number of N-step return steps (only when `buffer_type: "nstep"`). Default: `3`.
| `buffer_type`   | Class                     | Description                                                       |
|-----------------|---------------------------|-------------------------------------------------------------------|
| `"replay"`      | `ReplayBuffer`            | Uniform sampling from a deque. Simple and fast.                   |
| `"prioritized"` | `PrioritizedReplayBuffer` | PER with IS weights. Better for sparse rewards.                   |
| `"nstep"`       | `NstepReplayBuffer`       | N-step returns + uniform sampling. Accelerates value propagation. |
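To make the `"nstep"` row concrete, the sketch below shows the core idea of collapsing N consecutive transitions into one; the actual `NstepReplayBuffer` in `memory/replay_buffer.py` may handle storage and truncation differently.

```python
def nstep_transition(transitions, gamma):
    """Collapse up to N consecutive (state, action, reward, next_state, done)
    steps into a single transition with a discounted N-step reward.

    Bootstraps from the final next_state unless an episode boundary is hit earlier.
    """
    state, action = transitions[0][0], transitions[0][1]
    reward, next_state, done = 0.0, transitions[-1][3], transitions[-1][4]
    for k, (_, _, r, s_next, d) in enumerate(transitions):
        reward += (gamma ** k) * r
        if d:  # stop accumulating at episode end
            next_state, done = s_next, True
            break
    return state, action, reward, next_state, done
```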
PER parameters (active when `buffer_type: "prioritized"`):

- `per_alpha`: prioritization strength (`0.0` = uniform sampling).
- `per_beta_start`: initial beta value for IS weights.
- `per_beta_frames`: number of steps to anneal beta to `1.0`.
- `per_eps`: small constant added to priority for numerical stability.
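These parameters follow the standard PER formulation; the sketch below spells out the usual formulas with the config names above (implementation details are in `memory/replay_buffer.py`):

```python
import numpy as np

def per_probabilities(td_errors, per_alpha, per_eps):
    # Priorities: p_i = (|TD error| + per_eps) ** per_alpha; per_alpha = 0.0 gives uniform sampling.
    priorities = (np.abs(td_errors) + per_eps) ** per_alpha
    return priorities / priorities.sum()

def per_is_weights(probs, buffer_size, beta):
    # Importance-sampling weights: w_i = (N * P(i)) ** (-beta), normalized by the max weight.
    weights = (buffer_size * probs) ** (-beta)
    return weights / weights.max()

def anneal_beta(step, per_beta_start, per_beta_frames):
    # Linearly anneal beta from per_beta_start to 1.0 over per_beta_frames steps.
    return min(1.0, per_beta_start + step * (1.0 - per_beta_start) / per_beta_frames)
```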
Architecture parameters:
- `use_dueling`: enables/disables Dueling DQN (default: `False`). When `True`, the network separates state value and action advantage estimation for better generalization (see the sketch after this list). All artifacts (model, logs, metrics, plots) are stored separately for standard DQN and Dueling DQN using a suffix (`_standard` or `_dueling`).
- `network_type`: network type, `"mlp"` (default) or `"cnn"` (for image-based environments).
- `conv_layers`: list of Conv2d layers as tuples `(out_channels, kernel_size, stride)`. Default: `[(32, 8, 4), (64, 4, 2), (64, 3, 1)]`.
- `cnn_hidden_dim`: hidden layer size after the CNN trunk.
- `frame_stack`: number of frames stacked as observation (default: `4`).
- `frame_size`: target frame size `[H, W]` (default: `[84, 84]`).
- `frame_skip`: number of frames skipped (action repeat, default: `1`).
- `is_atari`: flag controlling the selection of Atari vs generic wrappers.
- `target_update_freq`: hard update of the target network every N steps (`0` = soft update with `tau`).
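When `use_dueling` is `True`, the network head splits into value and advantage streams and recombines them as Q(s, a) = V(s) + A(s, a) - mean_a A(s, a). Below is a minimal PyTorch sketch of this standard aggregation; the layer sizes here are illustrative, the real ones come from `hidden_layers` / `cnn_hidden_dim`.

```python
import torch.nn as nn

class DuelingHead(nn.Module):
    """Q(s, a) = V(s) + A(s, a) - mean_a A(s, a)."""

    def __init__(self, feature_dim, action_dim, hidden_dim=128):
        super().__init__()
        self.value = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, 1)
        )
        self.advantage = nn.Sequential(
            nn.Linear(feature_dim, hidden_dim), nn.ReLU(), nn.Linear(hidden_dim, action_dim)
        )

    def forward(self, features):
        value = self.value(features)          # [B, 1]
        advantage = self.advantage(features)  # [B, action_dim]
        return value + advantage - advantage.mean(dim=1, keepdim=True)
```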
Evaluation parameters:
- `eval_every`: how often (in training episodes) greedy policy evaluation is run (default: `100`).
- `eval_episodes`: number of evaluation episodes (default: `10`).
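Greedy evaluation simply runs episodes with epsilon fixed at 0 and aggregates the returns. The shared `evaluate_policy()` in `utils/evaluate.py` implements this; the sketch below only illustrates the idea, and its signature is not the repo's actual API.

```python
import numpy as np
import torch

def greedy_evaluation(env, policy_net, episodes=10, device="cpu"):
    rewards = []
    for _ in range(episodes):
        state, _ = env.reset()
        done, total = False, 0.0
        while not done:
            with torch.no_grad():
                obs = torch.as_tensor(state, dtype=torch.float32, device=device).unsqueeze(0)
                q = policy_net(obs)
            action = int(q.argmax(dim=1).item())  # epsilon = 0: always greedy
            state, reward, terminated, truncated, _ = env.step(action)
            total += reward
            done = terminated or truncated
        rewards.append(total)
    return float(np.mean(rewards)), float(np.std(rewards)), float(np.min(rewards)), float(np.max(rewards))
```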
By default, `python train.py` runs the preset for `CartPole-v1`.
To run training for another supported environment, pass it as an argument:

```bash
python train.py MountainCar-v0
python train.py Acrobot-v1
```

Note: only environments defined in `Config.ENV_CONFIG` are supported.
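The exact structure of `Config.ENV_CONFIG` is defined in `config/config.py`. Purely as a hypothetical illustration, a per-environment preset could map an environment name to overrides of the fields listed above (all keys and values below are invented for the example):

```python
# Hypothetical example only; see config/config.py for the real presets.
ENV_CONFIG_EXAMPLE = {
    "MountainCar-v0": {
        "num_episodes": 2000,
        "epsilon_decay": 0.999,
        "hidden_layers": [128, 128],
        "solved_threshold": -110.0,
        "model_path": "dqn_mountaincar.pth",
    },
}
```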
During training, the following are saved:
- model weights file (`*.pth`) to `config.model_path`,
- training curve plot (`*.png`) to `config.plot_path`,
- TensorBoard logs to the `logs/<env_name><suffix>_<YYYYMMDD-HHMMSS>/` directory,
- episode metrics to CSV: `metrics/<env_name>_<model_name>_<YYYYMMDD-HHMMSS>.csv`.
Example suffixes: `_standard` for standard DQN, `_dueling` for Dueling DQN. The suffix is automatically appended to model and plot names, and is therefore also visible in metric file names.
To view metrics during/after training:
```bash
tensorboard --logdir logs
```

Logged metrics include:

- `episode/reward`
- `episode/avg100`
- `episode/epsilon`
- `episode/loss`
- `episode/q_mean`
- `train/loss`
- `train/q_mean`
- `train/q_max_mean`
- `train/target_q_mean`
- `train/td_error_mean`
- `train/beta` (when `buffer_type: "prioritized"`)
- `train/is_weight_mean` (when `buffer_type: "prioritized"`)
- `train/priority_mean` (when `buffer_type: "prioritized"`)
- `eval/mean_reward` (greedy policy)
- `eval/std_reward` (greedy policy)
- `eval/min_reward` (greedy policy)
- `eval/max_reward` (greedy policy)
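Metrics like these are typically written with `torch.utils.tensorboard.SummaryWriter`; the minimal sketch below uses the tag names from the list above, though the actual call sites in `train.py` may differ.

```python
from torch.utils.tensorboard import SummaryWriter

writer = SummaryWriter(log_dir="logs/example_run")  # illustrative log_dir

# Inside the training loop (dummy values shown here):
episode, episode_reward, avg100, epsilon = 0, 0.0, 0.0, 1.0
writer.add_scalar("episode/reward", episode_reward, episode)
writer.add_scalar("episode/avg100", avg100, episode)
writer.add_scalar("episode/epsilon", epsilon, episode)

writer.close()
```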
CSV columns: `episode`, `reward`, `avg100`, `epsilon`, `beta`, `is_weight_mean`, `td_error_mean`, `priority_mean`.

A separate evaluation CSV file (`*_eval.csv`) contains columns: `episode`, `mean_reward`, `std_reward`, `min_reward`, `max_reward`.
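The metric CSVs can be inspected with any tabular tool. For example, assuming the column names above, a quick pandas plot of reward and its 100-episode average might look like this (the file path is illustrative):

```python
import pandas as pd
import matplotlib.pyplot as plt

# Illustrative path following the metrics/<env_name>_<model_name>_<timestamp>.csv pattern.
df = pd.read_csv("metrics/CartPole-v1_dqn_standard_20240101-120000.csv")

df.plot(x="episode", y=["reward", "avg100"])
plt.xlabel("episode")
plt.ylabel("reward")
plt.savefig("reward_vs_episode.png")
```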
Example artifacts visible in the repository:
- `dqn_cartpole.pth`
- `training_curve_cartpole.png`
- `play.py` will raise `FileNotFoundError` if the model does not exist.
- For reproducible results, set `seed` in `Config.DEFAULTS` or use the `--seed` flag when calling `train.py` (e.g. `python train.py MountainCar-v0 --seed 42`). The `--seed` flag overrides the config value (see the seeding sketch below).
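Full reproducibility typically requires seeding Python, NumPy, PyTorch, and the environment itself; the sketch below shows what that usually involves (the exact seeding performed by `train.py` may differ):

```python
import random
import numpy as np
import torch
import gymnasium as gym

def seed_everything(seed, env):
    random.seed(seed)
    np.random.seed(seed)
    torch.manual_seed(seed)
    # Gymnasium: seed the environment and its action space
    env.reset(seed=seed)
    env.action_space.seed(seed)

env = gym.make("CartPole-v1")
seed_everything(42, env)
```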
See CHANGELOG.md for a detailed history of changes.
This project is licensed under the MIT License.