# gas.py
# Author: Ronald L. Rivest
# Date: March 3, 2007
# Gas simulation (aka bouncing colored balls)
# Uses PyGame for graphics (see www.pygame.org)

###########################################################################
### License stuff                                                       ###
###########################################################################
"""
Copyright (C) 2006  Ronald L. Rivest

This program is free software; you can redistribute it and/or modify
it under the terms of the GNU General Public License as published by
the Free Software Foundation; either version 2 of the License, or (at
your option) any later version.

This program is distributed in the hope that it will be useful, but
WITHOUT ANY WARRANTY; without even the implied warranty of
MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the GNU
General Public License for more details.

You should have received a copy of the GNU General Public License
along with this program; if not, write to the Free Software
Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston, MA
02110-1301, USA.

[Note: the pygame library, upon which this is based, comes with the
 Lesser GPL license (LGPL).]

Email: rivest@mit.edu
Mail:  Room 32-G692, MIT, Cambridge, MA 02139
"""
###########################################################################
###########################################################################

import math
import random
import time
import sys

# Global constants:

# coordinates of "walls" of world
world_min_x = -4096.0            # minimum x in world coordinates
world_max_x = +4096.0            # maximum x in world coordinates
world_min_y = -4096.0            # minimum y in world coordinates
world_max_y = +4096.0            # maximum y in world coordinates

ball_min_radius = 16.0           # minimum radius for ball (world units)
ball_max_radius = 128.0          # maximum radius for ball (world units)

# Ball counts selectable with the UP/DOWN keys, sorted ascending; the keys
# step through this list.  A value given with the -balls command-line
# option need not appear in it.
number_balls_list = [0,1,2,3,4,5,6,7,8,9,10,15,20,25,30,40,50,75,100,
                     125,150,200,250,300,350,400,450,500,600,700,800,
                     900,1000,1200,1400,1600,1800,2000]

# Global variables:

balls = []                       # list of Ball objects being simulated
number_balls = 50                # (starting) number of balls; -balls overrides
speed = 24.0                     # velocity multiplier: each step moves a ball
                                 #    by (vx, vy) * speed world units
infrequent_display = False       # True if balls are redrawn only about every
                                 #    2 seconds (SPACE flips this), to save CPU time
autopause_period = 64            # Pause automatically every this many
                                 # simulation steps (0 if never)
paused = False                   # True while simulation steps are suspended
                                 #    (P toggles this)
total_collisions = 0             # total collisions counted
total_steps = 0                  # total simulation steps

###########################################################################
### Display (pygame) related stuff                                      ###
###########################################################################

## screen coordinates are (0,0) at upper left
##   x coordinates increase to right
##   y coordinates increase down

import pygame
from pygame.locals import *

# global constants 

White = [250,250,250]            # RGB; slightly off-white
Black = [0,0,0]                  # RGB

# global variables 

screen_size_x = 1200             # nominal screen size, set to fullscreen size later
screen_size_y = 800              # (measured in pixels)

color_scheme = 1                 # 0 = white background, 1 = black background
background_color = Black         # current background color (set by set_color_scheme)
line_color = White               # color of world walls and label text
                                 #    (set by set_color_scheme)

pixel_size = 9.0                 # size of pixel in world coordinates
                                 #    (PAGEUP/PAGEDOWN zoom by scaling this)
screen_center_x = 0.0            # center of screen in world coordinates
screen_center_y = 0.0            # center of screen in world coordinates

def convert_to_pixels(x,y):
    """
    Convert a point (x, y) given in world units to integer screen coords.

    The world point (screen_center_x, screen_center_y) maps to the middle
    of the screen, shifted up by 20 pixels.  Each pixel spans pixel_size
    world units.  Screen y grows downward, so larger world y is drawn lower.
    Returns the tuple (xs, ys).
    """
    half_width = screen_size_x//2
    half_height = screen_size_y//2
    xs = int((x - screen_center_x)/pixel_size + half_width)
    ys = int((y - screen_center_y)/pixel_size + half_height) - 20
    return (xs, ys)

# procedures

def set_color_scheme(scheme):
    """ 
    Set the global drawing colors for a color scheme.

    scheme = 0 (white background, black lines/text) or
             1 (black background, white lines/text).
    Any other value leaves line_color and background_color unchanged.
    """
    global line_color, background_color
    if scheme == 0:
        line_color = Black
        background_color = White
    elif scheme == 1:
        line_color = White
        background_color = Black

class Ball:
    """ 
    A single ball: position, velocity, radius, mass and display color,
    all chosen at random when the ball is created.
    """

    def __init__(self):
        # Random starting position anywhere inside the world walls (floats).
        self.x = random.uniform(world_min_x, world_max_x)
        self.y = random.uniform(world_min_y, world_max_y)

        # Random heading (radians) and speed (world units/step, before the
        # global speed multiplier), turned into velocity components.
        heading = random.uniform(0.0, 2*math.pi)
        magnitude = random.uniform(0.0, 2.0)
        self.vx = math.sin(heading)*magnitude
        self.vy = math.cos(heading)*magnitude

        # Integer radius in world units; mass proportional to area
        # (in (world units)**2, arbitrary scale).
        self.radius = int(random.uniform(ball_min_radius, ball_max_radius))
        self.mass = float(self.radius**2)

        # Random light RGB color: each channel is in 120..249.
        self.color = [120 + int(random.random()*130) for _channel in range(3)]

    def draw(self,surface):
        """ 
        Draw this ball as a filled circle on surface, skipping it when its
        bounding box lies entirely off screen.
        """
        (xs, ys) = convert_to_pixels(self.x, self.y)
        rs = int(self.radius/pixel_size)

        off_screen = (xs + rs < 0 or xs - rs > screen_size_x or
                      ys + rs < 0 or ys - rs > screen_size_y)
        if off_screen:
            return

        # width 0 means "fill the circle"
        pygame.draw.circle(surface, self.color, [xs, ys], rs, 0)

def initialize_screen():
    """ 
    Start pygame and open a fullscreen display.

    The largest mode reported by pygame.display.list_modes() is used and
    stored in screen_size_x / screen_size_y.  list_modes() may instead
    return -1 ("any size is fine") or an empty list (no modes).  In that
    case the nominal screen_size_x / screen_size_y values are kept, rather
    than crashing on sort()/indexing.
    Sets the global screen.
    """
    global screen, screen_size_x, screen_size_y
    pygame.init()

    modes = pygame.display.list_modes()    # defaults to fullscreen modes
    if isinstance(modes, list) and modes:
        # largest (width, height) pair, same choice as sorting and taking the last
        screen_size_x, screen_size_y = max(modes)

    screen = pygame.display.set_mode((screen_size_x, screen_size_y),
                                     pygame.FULLSCREEN)

    # (a window caption would be irrelevant for a full-screen display)

def initialize_font():
    """
    Load pygame's default font at size 36 into the global font, which is
    used for all on-screen text.
    """
    global font
    point_size = 36
    font = pygame.font.Font(None, point_size)

def initialize_background(color):
    """ 
    Create the global background surface, the same size as the screen,
    converted to the display's pixel format, and fill it with color.

    color: RGB color to fill with (callers pass background_color or White).
    """
    global background

    background = pygame.Surface(screen.get_size()).convert()
    background.fill(color)

def show_background():
    """ 
    Copy the background surface onto the screen at its top-left corner,
    then flip the display so the drawing becomes visible.
    """
    origin = (0, 0)
    screen.blit(background, origin)
    pygame.display.flip()


def show_text_screen(msgs, **param):
    """ 
    Show a screen of text and wait for the user to press a key.

    msgs:   lines of text, drawn centered in black on a white background,
            one line every 40 pixels starting at y = 100.
    param:  if the keyword 'welcome' is given (its value is ignored), the
            screen also shows the starting number of balls (changed with
            UP/DOWN) and the autopause period (changed with PGUP/PGDN).

    SPACE leaves the screen (and sets paused to False).  F1 shows the help
    screen and returns whatever that returns.
    Return True if user wishes to quit out of this text screen (ESC or
    closing the window).
    """
    global font, background, paused, number_balls, autopause_period

    def ball_text():
        """
        Draw a text object for the starting number of balls.
        """
        balltext = font.render("Starting number of balls (UP/DOWN):  " +
                               "%d" % number_balls,
                               1,
                               Black)
        balltextpos = balltext.get_rect()
        balltextpos.centerx = background.get_rect().centerx
        balltextpos.centery = 420
        # paint a wide white strip first, erasing the previously shown value
        pygame.draw.rect(background, White, balltextpos.inflate(1000, 0))
        background.blit(balltext, balltextpos)

    def autopause_text():
        """
        Draw a text object for the autopause period.
        """
        if autopause_period == 0:
            pausetext = font.render("Automatically pause (PGUP/PGDN):  Never",
                                    1,
                                    Black)
        else:
            pausetext = font.render("Automatically pause (PGUP/PGDN):  " +
                                    "every %d timesteps" % autopause_period,
                                    1,
                                    Black)
        pausetextpos = pausetext.get_rect()
        pausetextpos.centerx = background.get_rect().centerx
        pausetextpos.centery = 460
        pygame.draw.rect(background, White, pausetextpos.inflate(1000, 0))
        background.blit(pausetext, pausetextpos)

    def neighbour_ball_count(current, step):
        """
        Return the entry of number_balls_list just below (step < 0) or just
        above (step > 0) current, clamped to the ends of the list.  Unlike
        an index lookup, this also works when current is not in the list
        (e.g. it came from the -balls command-line option).
        """
        if step < 0:
            smaller = [n for n in number_balls_list if n < current]
            return smaller[-1] if smaller else number_balls_list[0]
        larger = [n for n in number_balls_list if n > current]
        return larger[0] if larger else number_balls_list[-1]

    initialize_background(White)

    y = 100                               # starting y coord, for first line
    for msg in msgs:
        text = font.render(msg, 
                           1,             # antialias
                           Black)         # color
        textpos = text.get_rect()
        textpos.centerx = background.get_rect().centerx
        textpos.centery = y
        y += 40
        background.blit(text, textpos)

    # dict.has_key() exists only in Python 2; 'in' works everywhere.
    iswelcome = 'welcome' in param
    
    if iswelcome:
        ball_text()
        autopause_text()

    show_background()
    
    # Now wait for user to hit a key, before proceeding.
    # pygame.event.wait() blocks until the next event instead of re-polling
    # an empty queue at full CPU.
    while True:
        event = pygame.event.wait()
        if event.type == QUIT:
            return True
        if event.type != KEYDOWN:
            continue
        if event.key == K_ESCAPE:
            return True
        elif event.key == K_SPACE:
            paused = False
            return False
        elif event.key == K_F1:
            return show_help_screen()
        elif not iswelcome:
            continue              # the settings below exist only on the welcome screen
        elif event.key == K_DOWN or event.key == K_UP:
            step = -1 if event.key == K_DOWN else 1
            number_balls = neighbour_ball_count(number_balls, step)
            ball_text()
            show_background()
        elif event.key == K_PAGEDOWN:
            if autopause_period == 0:
                autopause_period = 65536      # wrap from "never" to the longest period
            elif autopause_period < 10:
                autopause_period = 0
            else:
                # '//' keeps the period an int ('/' gives a float in Python 3)
                autopause_period //= 2
            autopause_text()
            show_background()
        elif event.key == K_PAGEUP:
            if autopause_period == 0:
                autopause_period = 8
            elif autopause_period >= 65536:
                # '>=' makes 65536 the top of the range, matching PGDN's wrap
                autopause_period = 0
            else:
                autopause_period *= 2
            autopause_text()
            show_background()

def show_welcome_screen():
    """ 
    Display the opening screen.  It also lets the user choose the starting
    number of balls and the autopause period.
    Return True if the user asked to quit from it.
    """
    blank = " "
    lines = ["gas.py -- gas simulation program",
             blank,
             "F1 at any time shows a help screen",
             "ESC at any time quits the program",
             "SPACE starts the program",
             blank,
             "(c) Ronald L. Rivest. 3/3/2007. Version 1.0.  GPL License.",
             ]
    quit_requested = show_text_screen(lines, welcome=True)
    return quit_requested

def show_help_screen():
    """ 
    Show the help screen listing the simulation's key bindings.
    Return True if user wishes to quit out of help screen.
    """

    # The autopause period is configurable (welcome screen or -autopause),
    # so the help text must not promise a fixed number of steps.
    msgs = ["Up arrow increases number of balls",
            "Down arrow decreases number of balls",
            " ",
            "Right arrow increases ball speed",
            "Left arrow decreases ball speed",
            "SPACE toggles whether to frequently update the display",
            "p (un)pauses (automatic pause period is set on the welcome screen)",
            " ",
            "PAGEUP zooms out",
            "PAGEDOWN zooms in",
            " ",
            "c flips background color (black/white)",
            " ",
            "SPACE proceeds",
            "F1 shows this help screen",
            "ESC quits"
            ]
    return show_text_screen(msgs)

last_step_time = time.time()        # when last simulation step started
last_display_time = time.time()     # when state was last displayed
steps_per_second = 10.0             # initial value only; main() sets it to
                                    #    1.0 / seconds_per_step after each step
seconds_per_step = 0.1              # moving average of the time between
                                    #    simulation steps (initially 1/10 s)

def display_balls(balls):
    """ 
    Draw the rectangle bounding the world, then every ball, onto the
    background surface.  The caller clears and shows the background.
    """
    width = 2
    x0, y0 = convert_to_pixels(world_min_x, world_min_y)
    x1, y1 = convert_to_pixels(world_max_x, world_max_y)

    # the four walls, as (start, end) pixel points
    walls = [((x0, y0), (x0, y1)),
             ((x1, y0), (x1, y1)),
             ((x0, y0), (x1, y0)),
             ((x0, y1), (x1, y1))]
    for start, end in walls:
        pygame.draw.line(background, line_color, start, end, width)

    for ball in balls:
        ball.draw(background)

def display_label():
    """
    Render two status lines near the bottom of the background.  The first
    has the key hints and the ball and collision counts.  The second has the
    step number and the measured simulation steps per second.
    """
    action = "resume" if paused else "pause"
    line1 = ("(F1/info, ESC/quit, P/" + action + ")        " +
             "%d balls         " % len(balls) +
             "%d collisions" % total_collisions)
    line2 = ("Step %d                " % total_steps +
             "%0.1f simulation steps / second" % steps_per_second)

    center_x = background.get_rect().centerx
    # (text, distance of its center above the bottom edge in pixels)
    for words, offset in ((line1, 60), (line2, 30)):
        rendered = font.render(words, 1, line_color)
        place = rendered.get_rect()
        place.centerx = center_x
        place.centery = screen_size_y - offset
        background.blit(rendered, place)

###########################################################################
### USER INPUT                                                          ###
###########################################################################

def handle_user_input():
    """ 
    Detect keypresses, etc., and handle them.

    - ESC or closing the window quits.
    - F1 shows the help screen, and quits if the help screen asks to.
    - UP/DOWN step the number of balls to the neighbouring entry of
      number_balls_list (clamped at its ends), adding new random balls or
      dropping balls from the end of balls.
    - RIGHT/LEFT multiply/divide speed by 1.4.  RIGHT is capped at a third
      of the smaller screen dimension.
    - c flips the color scheme.
    - p toggles paused.
    - SPACE toggles infrequent_display.
    - PAGEUP/PAGEDOWN zoom out/in by scaling pixel_size by 1.414.

    Return True iff user requests program to quit
    """
    global number_balls, balls, speed, screen_size_x, screen_size_y
    global color_scheme, infrequent_display, paused, pixel_size
    for event in pygame.event.get():
        if event.type == QUIT:
            return True
        if event.type != KEYDOWN:
            continue
        key = event.key
        if key == K_ESCAPE:
            return True
        elif key == K_F1:
            if show_help_screen():
                return True
        elif key == K_DOWN:
            # Largest listed count below the current one.  Searching by value
            # (not list index) also works when number_balls came from -balls
            # and is not in the list.
            smaller = [n for n in number_balls_list if n < number_balls]
            number_balls = smaller[-1] if smaller else number_balls_list[0]
            balls = balls[:number_balls]
        elif key == K_UP:
            # smallest listed count above the current one (see K_DOWN)
            larger = [n for n in number_balls_list if n > number_balls]
            number_balls = larger[0] if larger else number_balls_list[-1]
            while len(balls)<number_balls:
                balls.append(Ball())
        elif key == K_LEFT:
            speed /= 1.4
        elif key == K_RIGHT:
            speed *= 1.4
            speed = min(speed,screen_size_x/3,screen_size_y/3)
        elif key == K_c:
            color_scheme = (1+color_scheme)%2
            set_color_scheme(color_scheme)
        elif key == K_p:
            paused = not paused
        elif key == K_SPACE:
            infrequent_display = not infrequent_display
        elif key == K_PAGEUP:
            pixel_size *= 1.414
        elif key == K_PAGEDOWN:
            pixel_size = pixel_size / 1.414
    return False

###########################################################################
### Routines related to ball motion and collision handling              ###
###########################################################################

def dist(b1,b2):
    """ 
    Return the Euclidean distance (in world units) between the centers of
    balls b1 and b2.
    """
    dx = b1.x - b2.x
    dy = b1.y - b2.y
    return math.sqrt(dx**2 + dy**2)

def move_balls():
    """ 
    Perform one 'simulation step' of motion, aside from ball-to-ball
    collision handling.  move_ball is called on every ball in the global
    balls list.
    """
    for ball in balls:
        move_ball(ball)

def move_ball(b):
    """ 
    Move ball b one step (its velocity times the global speed) and bounce
    it off the edges of the world.

    A ball that ends the step overlapping a wall is reflected back inside
    by the distance it penetrated.  The matching velocity component is then
    made to point away from that wall.  If it already did (possible for a
    ball created overlapping a wall), it is left alone rather than flipped
    back into the wall.
    """
    b.x += b.vx * speed
    b.y += b.vy * speed

    r = b.radius

    # Allowed range for the ball's center along each axis.
    min_x = world_min_x + r
    max_x = world_max_x - r
    min_y = world_min_y + r       # wall at world_min_y (drawn at top of screen)
    max_y = world_max_y - r

    if b.x < min_x:               # bounce off left wall
        b.x = min_x + (min_x - b.x)
        b.vx = abs(b.vx)

    if b.x > max_x:               # bounce off right wall
        b.x = max_x - (b.x - max_x)
        b.vx = -abs(b.vx)

    if b.y < min_y:               # bounce off world_min_y wall
        b.y = min_y + (min_y - b.y)
        b.vy = abs(b.vy)

    if b.y > max_y:               # bounce off world_max_y wall
        b.y = max_y - (b.y - max_y)
        b.vy = -abs(b.vy)

### Vector operations

def vadd(v1,v2):
    """ Return the componentwise sum v1 + v2 as a new list (stops at the shorter input). """
    return [x + y for x, y in zip(v1, v2)]

def vsub(v1,v2):
    """ Return the componentwise difference v1 - v2 as a new list (stops at the shorter input). """
    return [x - y for x, y in zip(v1, v2)]

def vscale(s,v):
    """ Return a new list holding every component of v multiplied by the scalar s. """
    return [s * x for x in v]

def vlensq(v):
    """ Return the squared Euclidean length of v (sum of squared components). """
    return sum(x * x for x in v)

def vlen(v):
    """ Return the Euclidean length of v, i.e. the square root of vlensq(v). """
    squared = vlensq(v)
    return math.sqrt(squared)

def vdot(v1,v2):
    """ Return the dot product of v1 and v2 (sum of componentwise products). """
    return sum(x * y for x, y in zip(v1, v2))

def vunit(v):
    """
    Return unit vector in same direction as v.

    Raises ValueError if v has zero (or NaN) length, since no direction is
    defined.  This is an explicit check rather than an assert, so it still
    applies when Python runs with -O.
    """
    length = vlen(v)
    if not length > 0.0:
        raise ValueError("vunit: cannot normalize a zero-length vector")
    return vscale(1.0/length,v)

################################################
## COLLISION DETECTION AND COLLISION HANDLING ##
################################################

def detect_collisions(balls):
    """ 
    Detect and handle all ball-to-ball collisions.

    This uses an all-pairs approach (O(n^2) pair tests for n balls), which
    is OK for a reasonable number of balls.  Two balls 'collide' if they
    are touching or overlapping: the distance between their centers is at
    most the sum of their radii.

    Each ball is tested against every earlier ball in the list.  For each
    colliding pair, handle_collision(later_ball, earlier_ball) is called.
    This fixed order matters because each handled collision changes
    velocities.
    """
    for i, b0 in enumerate(balls):
        # handle_collision changes only velocities, never positions or
        # radii, so b0's fields can be read once per outer iteration.
        x0, y0, r0 = b0.x, b0.y, b0.radius
        for j in range(i):
            b1 = balls[j]
            dx = x0 - b1.x
            dy = y0 - b1.y
            reach = r0 + b1.radius
            # Squared comparison: same test as dist(b0, b1) <= reach (up to
            # rounding), without a sqrt for every pair.
            if dx*dx + dy*dy <= reach*reach:
                handle_collision(b0, b1)

def handle_collision(b1,b2):
    """ 
    Collide balls b1 and b2 elastically.

    Net result is that velocities of b1 and b2 may be changed; positions
    are not touched.  The collision normal is the line through the two
    centers.

    Detects "false collisions" and does nothing (and counts nothing) for
    them.  These are balls that are close but actually moving apart, or
    whose centers coincide exactly.  This case is important if balls have
    just collided but haven't really moved apart yet.

    This routine conserves energy and momentum.  Let n be the unit vector
    from b2's center to b1's, and a = (v1 - v2).n < 0 the approach speed:
        v1' = v1 - 2*m2/(m1+m2) * a * n
        v2' = v2 + 2*m1/(m1+m2) * a * n
    This equals reflecting each velocity about the normal in the
    center-of-mass frame.  Deciding once from the relative velocity
    guarantees both balls are updated together.  Normalizing p1 - p2
    directly avoids a zero vector when the center of mass rounds onto one
    of the centers.

    total_collisions is incremented only for real (approaching) collisions.
    """
    global total_collisions

    dx = b1.x - b2.x
    dy = b1.y - b2.y
    separation = math.hypot(dx, dy)
    if separation == 0.0:
        return                    # same position; no defined normal
    nx = dx / separation
    ny = dy / separation

    approach = (b1.vx - b2.vx) * nx + (b1.vy - b2.vy) * ny
    if approach >= 0.0:
        return                    # moving apart (or sliding): not a collision

    total_collisions += 1

    m1 = b1.mass
    m2 = b2.mass
    total_mass = m1 + m2
    k1 = 2.0 * m2 / total_mass * approach
    k2 = 2.0 * m1 / total_mass * approach
    b1.vx -= k1 * nx
    b1.vy -= k1 * ny
    b2.vx += k2 * nx
    b2.vy += k2 * ny


###########################################################################
### Main routine / event loop                                           ###
###########################################################################

def main():
    """
    Run the program.  Set up the display and font, then show the welcome
    screen.  Then run the simulation / event loop until the user quits.
    """
    global number_balls, balls, infrequent_display, background_color, total_steps
    global last_display_time, steps_per_second, last_step_time, seconds_per_step
    global paused
    random.seed(17)                  # same initial balls on every run

    initialize_screen()
    initialize_background(background_color)
    pygame.key.set_repeat(500,300)   # for handling key repeats

    initialize_font()
    if show_welcome_screen():
        return

    # Make initial set of balls
    balls = [Ball() for _ in range(number_balls)]

    # Steps at which we already paused automatically, so resuming at such a
    # step does not immediately pause again (prevents neverending auto-pause).
    autopaused_on = set()

    # Start the step timer here, so time spent on the welcome screen is not
    # counted as the duration of the first simulation step.
    last_step_time = time.time()

    # Event loop: each iteration of this loop is one 'simulation step'.
    while True:
        if handle_user_input():
            return

        # auto-pause (dict.has_key() was Python-2-only; a set works everywhere)
        if (total_steps > 0 and autopause_period > 0 and
                total_steps % autopause_period == 0 and
                total_steps not in autopaused_on):
            autopaused_on.add(total_steps)
            paused = True

        if paused:
            # Nothing to simulate: avoid spinning at full CPU.  Keep the
            # step timer current so the paused interval is not folded into
            # steps_per_second once the simulation resumes.
            time.sleep(0.01)
            last_step_time = time.time()
        else:
            total_steps += 1
            now = time.time()
            elapsed_step_time = now - last_step_time  # since last step computed
            last_step_time = now

            # seconds_per_step is computed as a moving average...
            seconds_per_step = 0.95 * seconds_per_step + 0.05*elapsed_step_time
            steps_per_second = 1.0 / seconds_per_step

            move_balls()
            detect_collisions(balls)

        elapsed_display_time = time.time()-last_display_time
        if not infrequent_display or elapsed_display_time>2.0:
            last_display_time = time.time()
            background.fill(background_color)
            display_balls(balls)
            display_label()
            show_background()

if __name__ == '__main__': 
    import profile

    # Parse command line arguments "-balls N" and "-autopause N".
    # These override the defaults at the top ("global variables"); a flag
    # given as the last argument, with no value after it, is ignored.
    args = sys.argv
    for flag, value in zip(args[1:], args[2:]):
        if flag == '-balls':
            number_balls = int(value)
        elif flag == '-autopause':
            autopause_period = int(value)

    # Profiling slows down the runtime by a significant factor.
    # We've turned it off to help you see the asymptotic behavior.
    # If you want to see profiling of the number of function calls
    # and how much time they take, you can uncomment the profiling line.

    #profile.run("main()")
    main()
