droneProject
Loading...
Searching...
No Matches
Train_PPO.py
Go to the documentation of this file.
1
# Train_PPO.py
# Train a PPO agent on the custom DroneRobot environment and save the
# resulting model under a timestamped filename.

from droneRobot import DroneRobot  # Custom drone-control environment (project-local)
from stable_baselines3 import PPO  # Proximal Policy Optimization implementation
from stable_baselines3.common.vec_env import DummyVecEnv  # Single-env vectorized wrapper
import os
import datetime

# Timestamped save path so repeated training runs never overwrite an
# earlier model file.
ppo_path = os.path.join(
    'Training',
    'Saved Models',
    f'Drone_PPO_Model_{datetime.datetime.now().strftime("%Y_%m_%d_%H_%M_%S")}',
)
# Directory where TensorBoard event files are written for monitoring.
log_path = os.path.join('Training', 'Logs')

# Number of full environment episodes to train for.
episodes = 100

# Build the drone simulation environment.
env = DroneRobot()
env.debugMode = True  # Extra diagnostic output from the environment while training
env.startTensorBoard(log_path)  # Launch TensorBoard so progress can be watched live

# Total timesteps = fixed steps per episode * number of episodes.
# Computed before wrapping, while the raw env attribute is still reachable.
timesteps = env.steps_per_episode * episodes

# Stable Baselines 3 expects a vectorized environment even for a single
# instance, so wrap the env in a DummyVecEnv.
env = DummyVecEnv([lambda: env])

# NOTE(review): ent_coef=10 is orders of magnitude above the typical
# 0.0-0.01 range and makes the entropy bonus dominate the PPO loss,
# forcing near-maximal exploration — confirm this is intentional.
model = PPO(
    'MlpPolicy',               # Neural-network (MLP) actor-critic policy
    env,
    verbose=1,                 # Print training statistics to stdout
    tensorboard_log=log_path,  # Where TensorBoard training logs are stored
    device='cpu',              # Train on CPU (change to 'cuda' for GPU)
    ent_coef=10,               # Entropy coefficient controlling exploration pressure
    learning_rate=0.0001,      # Optimizer learning rate
)

# Run the training loop for the computed number of timesteps.
model.learn(total_timesteps=timesteps)

# SB3's model.save does not create missing parent directories; without this
# the save raises FileNotFoundError on a fresh checkout.
os.makedirs(os.path.dirname(ppo_path), exist_ok=True)
model.save(ppo_path)

# Confirm completion to the operator.
print("Done training")
droneRobot.DroneRobot
Definition
droneRobot.py:28
controllers
Train_PPO.py
Generated by
1.12.0