Skip to content

Commit

Permalink
add neat
Browse files Browse the repository at this point in the history
  • Loading branch information
Fer14 committed Aug 28, 2024
1 parent f80b32e commit 9f16762
Show file tree
Hide file tree
Showing 5 changed files with 174 additions and 0 deletions.
Empty file added neat_/__init__.py
Empty file.
Binary file added neat_/car.png
Loading
Sorry, something went wrong. Reload?
Sorry, we cannot display this file.
Sorry, this file is invalid so it cannot be displayed.
71 changes: 71 additions & 0 deletions neat_/neat_car.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,71 @@
import pygame
import math
import os
import sys
import neat
import pickle

# Add this at the top of the file
import sys
from pathlib import Path

sys.path.append(str(Path(__file__).resolve().parent.parent))


from car import Car


CAR_SIZE_X = 30
CAR_SIZE_Y = 30

BORDER_COLOR = (255, 255, 255, 255) # Color To Crash on Hit


class NeatCar(Car):

def __init__(self, net=None, position=None):
super().__init__(position=position, angle=0)
self.net = net
self.sprite = pygame.image.load("./neat_/car.png").convert()
self.sprite = pygame.transform.scale(self.sprite, (CAR_SIZE_X, CAR_SIZE_Y))
self.rotated_sprite = self.sprite

def load_net(self):
local_dir = os.path.dirname(__file__)
config_path = os.path.join(local_dir, "config.txt")
config = neat.Config(
neat.DefaultGenome,
neat.DefaultReproduction,
neat.DefaultSpeciesSet,
neat.DefaultStagnation,
config_path,
)
with open("./neat_/checkpoints/2024-08-25/best_genome.pickle", "rb") as f:
genome = pickle.load(f)
self.net = neat.nn.FeedForwardNetwork.create(genome, config)

def get_reward(self):
# Calculate reward based on distance and velocity
distance_reward = self.distance / (CAR_SIZE_X / 2)
velocity_reward = self.speed / 20 # Assuming max speed is 20, adjust as needed

# Combine the rewards (you can adjust the weights)
total_reward = 0.7 * distance_reward + 0.3 * velocity_reward

return total_reward
# return self.distance / (CAR_SIZE_X / 2)

def action(self):
input = self.get_data()
output = self.net.activate(input)
choice = output.index(max(output))

if choice == 0:
self.angle += 10 # Left
elif choice == 1:
self.angle -= 10 # Right
elif choice == 2:
if self.speed - 2 >= 6:
self.speed -= 2 # Slow Down
else:
self.speed += 2 # Speed Up
42 changes: 42 additions & 0 deletions neat_/neat_race.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,42 @@
import pygame
from race import Race
from neat_car import NeatCar


class NeatRace(Race):

def training_race(self, cars: list[NeatCar], genomes):
clock = pygame.time.Clock()

counter = 0

running = True
while running:
clock.tick(120) # Cap the frame rate

# Exit On Quit Event
for event in pygame.event.get():
if event.type == pygame.QUIT:
# throw exception to stop the race
raise KeyboardInterrupt("Race interrupted")

# For Each Car Get The Acton It Takes
for car in cars:
car.action()
# Check If Car Is Still Alive
# Increase Fitness If Yes And Break Loop If Not
still_alive = 0
for i, car in enumerate(cars):
if car.is_alive():
still_alive += 1
car.update(self.game_map)
genomes[i][1].fitness += car.get_reward()

if still_alive == 0:
break

counter += 1
if counter == 30 * 40: # Stop After About 20 Seconds
break

self.draw(cars)
61 changes: 61 additions & 0 deletions neat_/train_neat.py
Original file line number Diff line number Diff line change
@@ -0,0 +1,61 @@
import os
import neat
from datetime import date
from neat_car import NeatCar
from neat_race import NeatRace

import sys
from pathlib import Path

sys.path.append(str(Path(__file__).resolve().parent.parent))

from race import Race


race = NeatRace(start=(1295, 966))


def eval_genomes(genomes, config, shuffle=False):

cars = []

for i, g in genomes:
net = neat.nn.FeedForwardNetwork.create(g, config)
g.fitness = 0
cars.append(NeatCar(net=net, position=race.start))

# race.load_random_map()
race.training_race(cars, genomes)


def run_neat(config):
p = neat.Population(config)
p.add_reporter(neat.StdOutReporter(True))
stats = neat.StatisticsReporter()
p.add_reporter(stats)

try:
winner = p.run(eval_genomes, 200)
except KeyboardInterrupt:
print("\nTraining interrupted. Saving current best genome...")
winner = p.best_genome
finally:
# Save the winner genome
# with open(f"checkpoints/{date.today()}/best_genome.pickle", "wb") as f:
# pickle.dump(winner, f)
print(f"Best genome saved to checkpoints/{date.today()}/best_genome.pickle")


if __name__ == "__main__":
local_dir = os.path.dirname(__file__)
config_path = os.path.join(local_dir, "config.txt")
os.makedirs(f"checkpoints/{date.today()}", exist_ok=True)

config = neat.Config(
neat.DefaultGenome,
neat.DefaultReproduction,
neat.DefaultSpeciesSet,
neat.DefaultStagnation,
config_path,
)
run_neat(config)

0 comments on commit 9f16762

Please sign in to comment.