droneProject
Loading...
Searching...
No Matches
Train_PPO.py
Go to the documentation of this file.
# Train a PPO (Proximal Policy Optimization) agent on the custom drone
# simulation environment and save the resulting model to disk.
#
# Review fixes in this revision:
#   * The environment is now explicitly constructed before use — the original
#     listing referenced `env` without ever initializing it (a line appears
#     to have been dropped).
#   * The DummyVecEnv factory closes over a dedicated name instead of
#     rebinding `env`, avoiding a late-binding-closure trap.

import datetime
import os

from stable_baselines3 import PPO  # RL algorithm implementation
from stable_baselines3.common.vec_env import DummyVecEnv  # single-env vectorization wrapper

from droneRobot import DroneRobot  # Project-local drone environment

# Unique save path: a timestamp suffix prevents overwriting earlier runs.
timestamp = datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")
ppo_path = os.path.join('Training', 'Saved Models', f'Drone_PPO_Model_{timestamp}')
log_path = os.path.join('Training', 'Logs')  # TensorBoard log directory

episodes = 100  # Number of full episodes' worth of experience to train on

# Create the drone-control environment.
# NOTE(review): the original listing was missing this initialization line;
# DroneRobot is assumed to take no required constructor arguments — confirm.
env = DroneRobot()
env.debugMode = True  # Extra diagnostic output during training
env.startTensorBoard(log_path)  # Launch TensorBoard logging of training progress

# Total timesteps = fixed steps per episode * episode count.
# Read steps_per_episode BEFORE wrapping: DummyVecEnv does not proxy
# custom attributes of the underlying environment.
timesteps = env.steps_per_episode * episodes

# SB3 requires a vectorized environment even for a single instance.
# A separate name keeps the closure unambiguous (no rebinding of `env`).
vec_env = DummyVecEnv([lambda: env])

model = PPO(
    'MlpPolicy',               # Neural-network (MLP) actor-critic policy
    vec_env,
    verbose=1,                 # Log training details to stdout
    tensorboard_log=log_path,  # Where TensorBoard summaries are written
    device='cpu',              # Change to 'cuda' to train on GPU
    # NOTE(review): ent_coef=10 is ~1000x the typical PPO range (0.0-0.01);
    # it makes the loss exploration-dominated. Preserved as written because
    # the original comments say it is deliberate — confirm it is intentional.
    ent_coef=10,
    learning_rate=0.0001,      # Optimizer step size
)

# Run the optimization loop for the computed number of environment steps.
model.learn(total_timesteps=timesteps)

# Persist the trained policy for later evaluation or deployment.
model.save(ppo_path)

print("Done training")