# droneProject: Train_SAC.py
"""Train a Soft Actor-Critic (SAC) agent on the DroneRobot environment.

The agent is rewarded relative to ``env.target_location``. Training runs for
``episodes * env.steps_per_episode`` timesteps. TensorBoard logs go under
``Training/Logs``. The final model is saved as
``Training/Saved Models/SAC_model/sac_Drone`` (Stable Baselines 3 adds the
``.zip`` extension).
"""
# Import necessary libraries
import os  # File and directory path management

import numpy as np  # Numerical operations; currently unused in this script
from droneRobot import DroneRobot  # Project class that contains the drone environment
from stable_baselines3 import SAC  # Soft Actor-Critic algorithm from Stable Baselines 3

# Paths for the trained model and for TensorBoard logging
model_path = os.path.join('Training', 'Saved Models', 'SAC_model')  # Directory the trained model is saved into
log_path = os.path.join('Training', 'Logs')  # TensorBoard log directory for monitoring training

# Initialize the environment (the drone control simulation).
# NOTE(review): this statement was missing from the listing (original line 12),
# which left `env` undefined. Constructing DroneRobot with no arguments is an
# assumption; confirm it against the DroneRobot constructor.
env = DroneRobot()
env.target_location = [0, 0, 2]  # X, Y, Z target location; all reward is calculated relative to this point
env.debugMode = False  # Set to True to enable debug mode for testing the PID constants defined in take_action()

# Number of episodes to train for; total timesteps follow from the steps per episode
episodes = 1000
timesteps = env.steps_per_episode * episodes

# Project helper that starts TensorBoard for log_path (see DroneRobot.startTensorBoard)
env.startTensorBoard(log_path)

# SAC with an MLP (multi-layer perceptron) policy for continuous control.
# - verbose=1 prints training progress.
# - device='cuda' runs on the GPU; Stable Baselines 3 falls back to the CPU
#   when CUDA is not available.
# - tensorboard_log writes the training curves to log_path.
model = SAC("MlpPolicy", env, verbose=1, device='cuda', tensorboard_log=log_path)

# Train for the computed number of timesteps. For off-policy algorithms such
# as SAC, log_interval is measured in episodes, so statistics are logged
# every 4 episodes.
model.learn(total_timesteps=timesteps, log_interval=4)

# Save the trained model as <model_path>/sac_Drone.
# BaseAlgorithm.save() takes the full file path as its single `path` argument.
# Passing a file name plus path=... supplies `path` twice and raises TypeError.
os.makedirs(model_path, exist_ok=True)
model.save(os.path.join(model_path, "sac_Drone"))