# client.py
#
# Last Modified: 10/16/06
# Desc:
#  simple python client implementation, sends well structured
#  request to server, prints well structured return.
#  some code adapted from: http://www.amk.ca/python/howto/sockets/

import socket
import time


# constants
PORT = 6857
HOST = "osprey.csail.mit.edu"
PAD  = "@"
MLEN = 256

#############################################################
# PRIMITIVES FOR LOCAL TESTING
#############################################################

# localcalc (defined below, after the D2Bin helper):
#   inputs:  starting base, iterations, exp, group
#   outputs: variance of the time taken

def D2Bin(n):
    '''Convert a non-negative integer n to its binary string.

    The result has no prefix and no leading zeros, except that 0 maps
    to '0'.

    Raises:
        ValueError: if n is negative.
    '''
    if n < 0:
        # Call-style raise is valid on both Python 2 and Python 3.
        raise ValueError("must be a non-negative integer")
    # bin() yields '0b...'; drop the two-character prefix.
    return bin(n)[2:]

def localcalc(start,iters,exp,group):
    """Simulate the timing experiment locally and return a variance.

    For each base in start .. start+iters-1, the simulated time of
    sam(base, exp, n) is subtracted from the simulated time of
    sam(base, d, n).  The population variance of those differences
    (divided by iters) is returned.

    group is accepted so the signature matches networkcalc; it is not
    used here.
    """
    d, n = 1, 1  #INSERT YOUR OWN 'SECRET' EXPONENT and MODULUS HERE

    # First pass: collect the timing differences and their running sum.
    diffs = []
    running_sum = 0
    for offset in range(iters):
        base = start + offset
        delta = sam(base, d, n) - sam(base, exp, n)
        diffs.append(delta)
        running_sum = running_sum + delta
    mean = running_sum / iters

    # Second pass: sum of squared deviations from the mean.
    sq_dev_sum = 0
    for delta in diffs:
        sq_dev_sum = sq_dev_sum + (delta - mean) ** 2

    return sq_dev_sum / iters

#returns fake timing for multiplying x and y
def ftime(x,y):
    """Return a fake, deterministic timing for multiplying x and y.

    The low 20 bits of x ^ y are scaled to a fraction in [0, 1), and the
    result is const * (1 + fraction), i.e. a float in [50, 100).

    Raises:
        ValueError: if x ^ y is negative.
    """
    const = 50
    mixed = x ^ y
    if mixed < 0:
        raise ValueError("x ^ y must be a non-negative integer")
    # Only the 20 least-significant bits contribute to the timing.
    low_bits = mixed & ((1 << 20) - 1)
    fraction = low_bits / float(1 << 20)
    return const * (1 + fraction)

#square and multiply.
#inputs: num, exp, modulus
#outputs: time taken for num**exp % modulus
def sam(a,b,n):
    """Return the simulated time of square-and-multiply for a**b % n.

    Walks the binary digits of b from most to least significant.  Each
    digit costs one squaring; a '1' digit additionally costs one
    multiplication by a.  Per-operation costs come from ftime, and only
    the accumulated time is returned (not the modular result).
    """
    elapsed = 0
    acc = 1
    for bit in D2Bin(b):
        # Square step, performed for every bit of the exponent.
        elapsed = elapsed + ftime(acc, acc)
        acc = (acc * acc) % n
        if bit == '1':
            # Multiply step, performed only for set bits.
            elapsed = elapsed + ftime(acc, a)
            acc = (a * acc) % n
    return elapsed

#############################################################
# END OF PRIMITIVES FOR LOCAL TESTING
#############################################################


##################################################################
# CLIENT CALC FUNCTIONS
##################################################################

def networkcalc(start,iters,exp,group):
    """Send one calculation request to the server and return its reply.

    The request is the four arguments joined with "|".  It is sent as a
    fixed-size MLEN message through sendStr, and a fixed-size MLEN reply
    is read back with recvStr.

    Returns:
        The reply with padding removed by trimPadding.

    Raises:
        RuntimeError: propagated from sendStr/recvStr when the
            connection breaks.  Socket errors from connect also
            propagate.
    """
    # construct request
    msg = "|".join([str(start), str(iters), str(exp), str(group)])
    if len(msg) > MLEN:
        # sendStr truncates to MLEN, so the server would see a cut-off request.
        print("WARNING: your values passed to calc are too large!")

    # initialize socket
    sock = socket.socket(socket.AF_INET, socket.SOCK_STREAM)
    try:
        # form connection
        sock.connect((HOST, PORT))

        # send request data
        print("sending request: %s" % (msg,))
        sendStr(msg, MLEN, sock)
        print("done sending.")

        # receive response
        print("waiting for response...")
        resp = recvStr(MLEN, sock)
        result = trimPadding(resp)
        print("received: %s" % (result,))
    finally:
        # Always release the socket, including when connect/send/recv raise.
        print("closing socket")
        print("")
        sock.close()

    # return edited response
    return result


def trimPadding(s):
    """Return s cut at the first pad character ("@").

    If s contains no pad character (for example a message that fills
    the whole fixed-size frame), s is returned unchanged rather than
    None.
    """
    # The literal matches the module-level PAD constant used when sending.
    pos = s.find("@")
    if pos > -1:
        return s[0:pos]
    return s


##################################################################
# CLIENT COMMUNICATION FUNCTIONS
##################################################################



# if the msg is longer than MSGLEN bytes, it truncates before sending.
# if the msg is shorter than MSGLEN bytes it pads to the appropriate length

def sendStr(rawMsg, MSGLEN, sock):
    """Send rawMsg over sock as a message of exactly MSGLEN characters.

    A longer rawMsg is truncated to MSGLEN; a shorter one is right-padded
    with PAD.  sock.send is called repeatedly until the whole message
    has been handed to the socket.

    Raises:
        RuntimeError: if sock.send reports 0 bytes sent (connection
            broken).
    """
    msg = rawMsg[:MSGLEN]
    # Pad up to the fixed length; the multiplier is 0 when already full.
    msg = msg + PAD * (MSGLEN - len(msg))

    totalsent = 0
    while totalsent < MSGLEN:
        # send() may accept only part of the buffer; resume after it.
        sent = sock.send(msg[totalsent:])
        if sent == 0:
            print("ERROR: socket connection broken")
            raise RuntimeError("socket connection broken")
        totalsent = totalsent + sent



def recvStr(MSGLEN,sock):
    """Read exactly MSGLEN characters from sock and return them.

    sock.recv is called repeatedly, each time asking only for what is
    still missing, until MSGLEN characters have been collected.

    Raises:
        RuntimeError: if sock.recv returns an empty chunk before the
            full message has arrived (connection broken).
    """
    chunks = []
    received = 0
    while received < MSGLEN:
        chunk = sock.recv(MSGLEN - received)
        # Truthiness test catches any empty chunk, so a closed
        # connection can never spin this loop forever.
        if not chunk:
            print("ERROR: socket connection broken")
            raise RuntimeError("socket connection broken")
        chunks.append(chunk)
        received += len(chunk)
    return ''.join(chunks)

def calc(start,iters,exp,group,local):
    """Run the calculation locally or against the server.

    A true value for local dispatches to localcalc; otherwise the
    request goes to networkcalc.  The chosen function's return value is
    passed straight through.
    """
    backend = localcalc if local else networkcalc
    return backend(start, iters, exp, group)
        


##################################################################
##################################################################
# MAIN BODY OF CODE
#   ( put your code below here )
##################################################################
##################################################################


# Sample calculation:
#   - start with value 1234
#   - look at 1000 values from 1234 + 0 to 1234 + 999
#   - use 3 as your guess exponent
#   - use the secret exponent for group 9
#   - do local calculations and testing first (set local = 1 below);
#     local = 0 sends the request to the server instead

# `local` is the switch read by calc(): a true value runs localcalc, while a
# false value (such as the 0 used here) sends the request to the server
# through networkcalc.
local = 0
# Arguments: start value, iteration count, guessed exponent, group number.
var = calc(1234,1000,3,9,local)
print "returned variance was:", var


# To check that your guess is correct, use the following info:
##Gp 0
##modulus =  15581537172818877915
##1234**d % n =  6303898114475531266
##Gp 1
##modulus =  17704010530601115158
##1234**d % n =  15376296443473299616
##Gp 2
##modulus =  17041301202720561625
##1234**d % n =  1728711692918985514
##Gp 3
##modulus =  16451628467532736531
##1234**d % n =  12767638248700099075
##Gp 4
##modulus =  13312857304338297436
##1234**d % n =  5619019465787160016
##Gp 5
##modulus =  17549346984722156199
##1234**d % n =  11488812915583350268
##Gp 6
##modulus =  11316512955175174482
##1234**d % n =  3018189145972402168
##Gp 7
##modulus =  13345998508868831733
##1234**d % n =  12287310500667965992
##Gp 8
##modulus =  12228088983480876126
##1234**d % n =  7150035132948436108
##Gp 9
##modulus =  15482624302785109807
##1234**d % n =  15232922726299810647
##Gp 10
##modulus =  10590287349277160648
##1234**d % n =  3857588523496782616
##Gp 11
##modulus =  10617257947229531295
##1234**d % n =  3656617397942155681
##Gp 12
##modulus =  17914933398068988198
##1234**d % n =  6669903620704508266
##Gp 13
##modulus =  10179347544068221109
##1234**d % n =  6545662251429218137
