#
#  fishPy.py
#  
#  Copyright 2020 Thomas E. Padgett
#  Contact: <thomas.e.padgett@outlook.com>
#  
#  This program is free software; you can redistribute it and/or modify
#  it under the terms of the GNU General Public License as published by
#  the Free Software Foundation; either version 2 of the License, or
#  (at your option) any later version.
#  
#  This program is distributed in the hope that it will be useful,
#  but WITHOUT ANY WARRANTY; without even the implied warranty of
#  MERCHANTABILITY or FITNESS FOR A PARTICULAR PURPOSE.  See the
#  GNU General Public License for more details.
#  
#  You should have received a copy of the GNU General Public License
#  along with this program; if not, write to the Free Software
#  Foundation, Inc., 51 Franklin Street, Fifth Floor, Boston,
#  MA 02110-1301, USA.
#

#%%
############################################################################### PREAMBLE
import numpy as np
import random
import scipy
from scipy import io
import scipy.interpolate as spint
import matplotlib.pyplot as plt
from mpl_toolkits.mplot3d import Axes3D
import sys
import time
import csv

# Reference point on the performance counter, taken as soon as the script starts.
start_time = time.perf_counter()

# Start-up banner; one call, one banner line per argument.
print(
    "\n \n ------------------------------------------------------- ",
    "                 ------- fishPy -------               ",
    "              ----- Initialising Code -----           ",
    " ------------------------------------------------------- " + "\n",
    sep="\n",
)


#%%
############################################################################### DEBUG MODE INFO
# Global run flags.
# DEBUG switches fish creation over to the fixed values stored in `debug`
# (creation points, bodylengths and number of fish).
# NOTE(review): PLOT, WRITE and CSVWRITE are not used in this part of the
# file; presumably they toggle plotting and file output - confirm.
PLOT = True
DEBUG = False
WRITE = True
CSVWRITE = True

# Values used only when DEBUG mode is enabled.
debug = {
    # One [x, y, z] creation point per fish ID.
    'creations': np.array([
        [14.640, 1.0, 1.461],
        [14.640, 1.0, 1.461],
        [14.755, 1.0, 1.467],
        [25.0, 1.5, 1.0],
        [25.0, 2.5, 1.0],
    ]),
    # One plot colour per fish ID.
    'plotColours': ['b', 'r', 'g', 'y', 'm'],
    # One bodylength per fish ID.
    'bodylengths': [0.15, 0.2, 0.15, 0.15, 0.15],
    # Number of fish created in DEBUG mode.
    'fishNumReq': 2,
}

#%% RULES
############################################################################### RULES

# Behavioural rule switches. When a switch is False, fish.runRules replaces
# that rule's contribution with a zero vector instead of calling the rule.
followFlowSwitch = True       # follow-flow direction (followFlow)
minMaxEnergySwitch = True     # min/max energy search (minMaxEnergyFunc)
randomWalkSwitch = True       # random walk (randomWalk)
obAvoidanceSwitch = True      # obstacle avoidance (obAvoidRule)
colAvoidanceSwitch = False    # avoidance of other fish (colAvoidRule); also
                              # enables the spacing check at creation
memorySwitch = True           # memory of past directions (memoryRule)
tkeAvoidanceSwitch = True     # TKE avoidance (tkeAvoid2)

# NOTE(review): the two switches below are not used in this part of the
# file; presumably they control velocity scaling and burst swimming - confirm.
scaledVels = True
burstSwitch = True

#%%
############################################################################### USER INPUTS
# Define required data.
data = {
    # Creation zone bounding box
    'creationZoneXmin': 25.0,       # Lower X limit of creation zone bounding box
    'creationZoneXmax': 27.5,       # Upper X limit of creation zone bounding box
    'creationZoneYmin': 3.0,        # Lower Y limit of creation zone bounding box
    'creationZoneYmax': 5.0,        # Upper Y limit of creation zone bounding box
    'creationZoneZmin': 0.0,        # Lower Z limit of creation zone bounding box
    'creationZoneZmax': 2.0,        # Upper Z limit of creation zone bounding box

    # Target zone bounding box
    'targetZoneXmin': 0.0,          # Lower X limit of target zone bounding box
    'targetZoneXmax': 5.0,          # Upper X limit of target zone bounding box
    'targetZoneYmin': 0.0,          # Lower Y limit of target zone bounding box
    'targetZoneYmax': 10.0,         # Upper Y limit of target zone bounding box
    'targetZoneZmin': -10.0,        # Lower Z limit of target zone bounding box
    'targetZoneZmax': 10.0,         # Upper Z limit of target zone bounding box

    # Modify below as per your requirements.
    'fishTimestep': 0.1,            # This is the fish time step (s)
    'bodylength_mean': 0.5,         # Mean bodylength of fish sample (Normal Dist) (m)
    'bodylength_deviation': 0.05,   # Standard deviation of fish sample (m)
    'fishNumReq': 5,                # Number of fish
    'Tmax': 1000,                   # Maximum number of timesteps
}

# No changes in below values
# In DEBUG mode the number of fish comes from the debug settings instead.
data['fishNum'] = debug['fishNumReq'] if DEBUG == True else data['fishNumReq']

data.update({
    'SensoryRange': 1.0,            # This is the fish sensory range for a 1 second fish timestep (units -> bodylength)
    # Repulsion distance (m): half of (mean bodylength + two standard deviations)
    'repulsionDist': (data['bodylength_mean']+(2*data['bodylength_deviation']))/2,
    'fallbackMax': 100,             # Maximum number of fallbacks per individual
    'passabilityThreshold': 0.1,    # Threshold defining minimum geometry
    'rho_w': 1000,                  # Density of water (kg/m3)
    'rho_f': 1081,                  # Density of fish sans bladder (kg/m3)
    'mu': 0.001,                    # Dynamic viscosity of water Pa.s
    'AtmPressure': 101325,          # Atmospheric pressure Pa
    'waterTemperature': 10,         # Nominal value of water temperature (Celsius)
})

# Initialises the distMoved list (one entry per fish) for later use.
distMoved = list(range(data['fishNum']))

    
#%%
############################################################################### ENVIRONMENT DATA
# Define which environmental data used:
envData = scipy.io.loadmat('domains/veriSetB_output10cm.mat')

dom = {}
# Define the resolution:
dom['spX'] = 0.1 # grid spacing in x (m)

# The rest of the script reads the geometry/passability field as 'G3D'.
# Accept the alternative names 'POR3D' or 'VF3D' and fail immediately, with
# a clear message, when none of them is present (previously this surfaced
# later as a bare KeyError: 'G3D').
if 'G3D' not in envData:
    if 'POR3D' in envData:
        envData['G3D'] = envData['POR3D']
    elif 'VF3D' in envData:
        envData['G3D'] = envData['VF3D']
    else:
        raise KeyError("Environment data contains none of 'G3D', 'POR3D' or 'VF3D'")

# Grid size is taken from the shape of the U velocity field.
dom['nx'], dom['ny'], dom['nz'] = envData['U3D'].shape
dom['spY'] = dom['spX'] # grid spacing in y (m)
dom['spZ'] = dom['spY'] # grid spacing in z (m)

if 'X3D' in envData:
    # Node coordinates are supplied: take the domain extents from them.
    # Only 'X3D' is checked; 'Y3D' and 'Z3D' are read without a check.
    dom['maxX'] = np.amax(envData['X3D'])
    dom['maxY'] = np.amax(envData['Y3D'])
    dom['maxZ'] = np.amax(envData['Z3D'])

    dom['minX'] = np.amin(envData['X3D'])
    dom['minY'] = np.amin(envData['Y3D'])
    dom['minZ'] = np.amin(envData['Z3D'])

    # Evenly spaced coordinate vectors between the extents, one entry per node.
    dom['vectX'] = np.linspace(dom['minX'],dom['maxX'],num=dom['nx']) # Vector of x
    dom['vectY'] = np.linspace(dom['minY'],dom['maxY'],num=dom['ny']) # Vector of y
    dom['vectZ'] = np.linspace(dom['minZ'],dom['maxZ'],num=dom['nz']) # Vector of z

else:
    # No node coordinates supplied: nodes are placed at one grid spacing,
    # two grid spacings, ... up to n grid spacings along each axis.
    # NOTE(review): the original comment here said "We assume X,Y,Z all start
    # at zero", but the first node is placed at one spacing - confirm intent.
    dom['minX'] = dom['spX']
    dom['minY'] = dom['spY']
    dom['minZ'] = dom['spZ']

    dom['maxX'] = (dom['nx'])*dom['spX']
    dom['maxY'] = (dom['ny'])*dom['spY']
    dom['maxZ'] = (dom['nz'])*dom['spZ']

    dom['vectX'] = np.linspace(dom['minX'],dom['maxX'],num=dom['nx'])
    dom['vectY'] = np.linspace(dom['minY'],dom['maxY'],num=dom['ny'])
    dom['vectZ'] = np.linspace(dom['minZ'],dom['maxZ'],num=dom['nz'])

    # Construct the x,y,z coords of each node. indexing='ij' yields arrays of
    # shape (nx, ny, nz) directly, so no axis swap is needed afterwards.
    envData['X3D'],envData['Y3D'],envData['Z3D'] = np.meshgrid(dom['vectX'],dom['vectY'],dom['vectZ'], indexing='ij')

# Half of the upper extent along each axis.
# NOTE(review): this is the domain centre only when the minimum is 0 - confirm.
dom['midX'] = dom['maxX']/2
dom['midY'] = dom['maxY']/2
dom['midZ'] = dom['maxZ']/2
    

# Convert from passability to DEM
# Builds envData['DEM'], an (nx, ny) array holding one elevation value per
# (i, j) column of the 3D grid.
print('--- Creating DEM file... \n')
if 'POR3D' in envData:
    # Elevation = (number of z nodes - sum of POR3D over the column) * spZ.
    envData['DEM'] = np.zeros([dom['nx'],dom['ny']])
    for i in range(0,dom['nx']):
        for j in range(0,dom['ny']):
            envData['DEM'][i][j] = (dom['nz'] - sum(envData['POR3D'][i][j][:]))*dom['spZ']
else:
    # Walk up each column from k = 0 while G3D is below 1. The elevation is
    # the Z coordinate of the last node reached (the first node with
    # G3D >= 1, or the top node). Columns whose bottom node already has
    # G3D >= 1 keep the initial value 0.0.
    envData['DEM'] = np.zeros([dom['nx'],dom['ny']])
    for i in range(dom['nx']):
        for j in range(dom['ny']):
            k = 0
            while envData['G3D'][i][j][k] < 1 and k < dom['nz']-1:
                k += 1
                #print(k)
                tempVar = envData['Z3D'][i][j][k]
                #print(tempVar)
                envData['DEM'][i][j] = tempVar
    # NOTE(review): the function definitions that follow this block are
    # indented one level and therefore belong to this 'else' suite, so they
    # are only defined when 'POR3D' is NOT in envData. Confirm whether they
    # should be moved to module level.

        
#%%
############################################################################### FUNCTION DEFINITIONS
    def determineCreationLocation(ID,thresholdDistance):
        # Determines a suitable location to create fish number ID and returns
        # it as np.array([x, y, z]).
        #
        # ID                - index of the fish being created; fish 0..ID-1
        #                     are expected to exist already in `fishes`.
        # thresholdDistance - minimum allowed distance to the centroid of any
        #                     previously created fish (only used when
        #                     colAvoidanceSwitch is enabled).
        #
        # In DEBUG mode the pre-defined point debug['creations'][ID] is used.
        # Otherwise random points are drawn from the user-defined creation
        # bounding box until one passes both checks. The first candidate that
        # passes is returned, and both checks are evaluated on that same
        # candidate. Previously at least 10 candidates were always drawn, and
        # a spacing check passed by an earlier candidate was carried over to
        # later ones.
        # NOTE: the search does not terminate if the creation zone contains
        # no acceptable point.

        # If DEBUG mode enabled, just use pre-defined creation points
        if DEBUG == True:
            x = debug['creations'][ID][0]
            y = debug['creations'][ID][1]
            z = debug['creations'][ID][2]
            return np.array([x,y,z])

        # Otherwise try to find a suitable point:
        while True:
            # Pick random point within bounding box
            x = random.uniform(data['creationZoneXmin'],data['creationZoneXmax'])
            y = random.uniform(data['creationZoneYmin'],data['creationZoneYmax'])
            z = random.uniform(data['creationZoneZmin'],data['creationZoneZmax'])

            # Reject the point if its passability is below the threshold.
            passability = interrogatePassability(x,y,z)
            if passability < data['passabilityThreshold']:
                print('passabilityCheck failed')
                continue
            print('passabilityCheck true')

            # Without collision avoidance, or for the very first fish, there
            # is nothing to keep a distance from.
            if colAvoidanceSwitch == False or ID == 0:
                print('thresholdCheck true')
                break

            # Otherwise the point must not be closer than thresholdDistance to
            # any fish already created (a distance of exactly 0 is ignored,
            # as before).
            candidate = np.array([x,y,z])
            tooClose = False
            for j in range(ID):
                dist = calcDistance(candidate,fishes[j].coordsCentroid)
                if dist != 0 and dist < thresholdDistance:
                    tooClose = True
                    break
            if not tooClose:
                break

        loc = np.array([x,y,z])
        return loc
    
    def checkCreationLocation(ID,thresholdDistance,x,y,z):
        # Checks the acceptability of a creation point (x,y,z) for fish
        # number ID. Returns True only when
        #   * the passability at the point exceeds data['passabilityThreshold'], and
        #   * collision avoidance is disabled, or no previously created fish
        #     (fishes[0..ID-1]) lies closer than thresholdDistance
        #     (distances of exactly 0 are ignored).

        # Spacing to the fish created so far.
        if colAvoidanceSwitch == False:
            spacingOK = True
        else:
            candidate = np.array([x,y,z])
            separations = [calcDistance(candidate,fishes[j].coordsCentroid) for j in range(ID)]
            spacingOK = not any(d != 0 and d < thresholdDistance for d in separations)

        # Passability at the point: high passability means water.
        passabilityOK = False
        if interrogatePassability(x,y,z) > data['passabilityThreshold']:
            passabilityOK = True

        return passabilityOK and spacingOK

    def interrogateVelocity3D(x,y,z):
        # Function accepts coords (XYZ) and returns np.array([U, V, W]), the
        # velocity components linearly interpolated from envData['U3D'],
        # envData['V3D'] and envData['W3D'] on the regular grid defined by
        # dom['vectX'], dom['vectY'], dom['vectZ'].
        # With bounds_error=True the interpolator raises ValueError for a
        # point outside the grid.
        gridAxes = [dom['vectX'],dom['vectY'],dom['vectZ']]
        vel = np.array([0.0,0.0,0.0])

        for component, field in enumerate(('U3D','V3D','W3D')):
            interpolator = spint.RegularGridInterpolator(gridAxes, envData[field], method='linear', bounds_error=True)
            # The interpolator returns a one-element array for a single
            # point. Take the element explicitly: assigning a size-1 array to
            # a scalar slot is deprecated in recent NumPy releases.
            vel[component] = interpolator([x,y,z])[0]

        return vel
    
    def interrogateVelocity2D(x,y,z):
        # Function accepts coords (XYZ) and returns np.array([U, V, 0.0]).
        # Only the U and V components are interpolated (from envData['U3D']
        # and envData['V3D']); the third entry is always left at 0.0.
        # With bounds_error=True the interpolator raises ValueError for a
        # point outside the grid.
        gridAxes = [dom['vectX'],dom['vectY'],dom['vectZ']]
        vel = np.array([0.0,0.0,0.0])

        for component, field in enumerate(('U3D','V3D')):
            interpolator = spint.RegularGridInterpolator(gridAxes, envData[field], method='linear', bounds_error=True)
            # The interpolator returns a one-element array for a single
            # point. Take the element explicitly: assigning a size-1 array to
            # a scalar slot is deprecated in recent NumPy releases.
            vel[component] = interpolator([x,y,z])[0]

        return vel
    
    def interrogatePassability(x,y,z):
        # Function accepts coords (XYZ) and returns the passability at that
        # location, linearly interpolated from envData['G3D']. The value comes
        # back exactly as the interpolator returns it (a one-element array
        # for a single point).
        gridAxes = [dom['vectX'],dom['vectY'],dom['vectZ']]
        passabilityField = spint.RegularGridInterpolator(gridAxes, envData['G3D'], method='linear')
        return passabilityField([x,y,z])
    
    def interrogateInDomain(x,y,z):
        # Function that accepts (xyz) coords and returns True if the coords
        # are within the domain limits stored in `dom` (limits included),
        # False otherwise.
        insideX = dom['minX'] <= x <= dom['maxX']
        insideY = dom['minY'] <= y <= dom['maxY']
        insideZ = dom['minZ'] <= z <= dom['maxZ']

        if insideX and insideY and insideZ:
            return True
        return False
        
    def interrogateTKE(x,y,z):
        # Function accepts coords (XYZ) and returns the value of
        # envData['K3D'] linearly interpolated at that location, exactly as
        # the interpolator returns it (a one-element array for a single point).
        # With bounds_error=True the interpolator raises ValueError for a
        # point outside the grid.
        gridAxes = [dom['vectX'],dom['vectY'],dom['vectZ']]
        tkeField = spint.RegularGridInterpolator(gridAxes, envData['K3D'], method='linear', bounds_error=True)
        return tkeField([x,y,z])
        
    def calcMagnitude(vector):
        # Accepts a vector (U, V, W) and returns its Euclidean magnitude.
        # Only the first three entries of `vector` are used.
        u, v, w = vector[0], vector[1], vector[2]
        return np.sqrt(u**2 + v**2 + w**2)

    def calcUnit(vector):
        # Function accepts a vector and returns the unit vector (the vector
        # divided by its magnitude). A zero-magnitude vector cannot be
        # normalised: a warning is printed and the input is returned unchanged.
        magnitude = calcMagnitude(vector)
        if magnitude != 0.0:
            return vector/magnitude

        print('\n --- Warning: Zero magnitude detected.')
        print(' --- Warning: Impossible to define unit vector.')
        return vector
        
    def calcDistance(array1,array2):
        # Returns the Euclidean distance between two points, using the first
        # three entries (x, y, z) of each argument.
        dx = array1[0]-array2[0]
        dy = array1[1]-array2[1]
        dz = array1[2]-array2[2]
        return np.sqrt(dx**2 + dy**2 + dz**2)
    
    def calcVector(array1, array2):
        # Returns the component-wise difference array1 - array2 of the first
        # three entries as a plain Python list [dx, dy, dz].
        return [array1[axis] - array2[axis] for axis in range(3)]
    
    def initiativeFunction(distMoved, NumFish):
        # Returns the fish indices 0..NumFish-1 ordered by distance moved,
        # largest first. Fish with equal distances keep their index order.
        #
        # distMoved - distance moved by each fish, indexed by fish number.
        # NumFish   - number of fish to rank.
        #
        # Returns a list of indices. An empty list is returned when there is
        # nothing to rank (previously this raised IndexError).
        ranked = sorted(zip(range(NumFish), distMoved), key=lambda pair: pair[1], reverse=True)
        return [index for index, _ in ranked]
    
    def loadingScreen():
        # Writes a short "Simulating fish movements" animation to stdout:
        # five frames with 0 to 4 trailing dots, 0.1 s apart.
        blank = "                                                                 "
        # Overwrite anything in the print space with a large blank bit.
        sys.stdout.write('\r' + blank)
        for dots in range(5):
            frame = " --- Simulating fish movements" + "." * dots
            # Carriage returns keep every frame on the same console line.
            sys.stdout.write('\r' + frame + '\r')
            sys.stdout.flush()
            time.sleep(0.1)

    def calcReynoldsNumber(U,L):
        # Returns the Reynolds number rho_w*U*L/mu for velocity U and length
        # L, using the water density and dynamic viscosity stored in `data`.
        return (data['rho_w']*U*L)/data['mu']
    
    def calcAuxiliaryVector(vector, angle):
        # Function which returns the input vector rotated by `angle` (degrees)
        # about the global y-axis. A vector parallel to the y-axis is
        # therefore returned unchanged.
        theta = angle*np.pi/180
        cosT = np.cos(theta)
        sinT = np.sin(theta)
        rotation = np.array([
            [cosT, 0, sinT],
            [0, 1, 0],
            [-sinT, 0, cosT],
        ])
        return rotation.dot(vector)
    
    def angleOrthogonalChecker(X,Y,Z):
        # Returns True when the three vectors are mutually orthogonal, i.e.
        # each pairwise angle equals 90 degrees after rounding to 10 decimal
        # places; returns False otherwise.
        def angleBetween(a, b):
            # Angle between a and b in degrees.
            return np.arccos(np.dot(a,b)/(calcMagnitude(a)*calcMagnitude(b)))*(180/np.pi)

        angles = [angleBetween(X,Y), angleBetween(Y,Z), angleBetween(X,Z)]
        return all(round(angle,10) == 90.0 for angle in angles)
    
    def polarise(A,B):
        # Returns 1 when A is strictly greater than the threshold B, else 0.
        return 1 if A > B else 0
    
    def findMinIndices(array):
        # Returns a list with the index of every element equal to the
        # minimum of `array`, in ascending index order.
        smallest = min(array)
        return [position for position, value in enumerate(array) if smallest == value]

    def calcPointPlaneDistance(point,plane):
        # Returns the distance between the point (x,y,z) and the plane
        # (A,B,C,D), i.e. |A*x + B*y + C*z + D| / sqrt(A^2 + B^2 + C^2).
        a, b, c, d = plane[0], plane[1], plane[2], plane[3]
        numerator = np.absolute(a*point[0] + b*point[1] + c*point[2] + d)
        normalLength = np.sqrt(a**2 + b**2 + c**2)
        return numerator/normalLength
    
#%%   
############################################################################### CLASS DEFINITIONS

class fish(object):
    
    def __init__(self, UniqueID, x, y, z):
        # Creates one fish with identifier UniqueID whose centroid starts at
        # (x, y, z). Reads the module-level DEBUG / debug / data settings and
        # interrogates the local flow velocity to set the initial heading.
        
        self.id = UniqueID          # ID number
        
        self.x = x
        self.y = y
        self.z = z        
        
        # Current centroid and creation point (two separate arrays).
        self.coordsCentroid = np.array([self.x, self.y, self.z])
        self.creation = np.array([self.x, self.y, self.z])
        # Position history starts as two identical rows, shape (2, 3).
        self.history = np.vstack([self.coordsCentroid, self.coordsCentroid])
        
        # Define distance moved history
        self.distMovedHistory = 0.0
        
        # Randomly selected bodylength from Gaussian distribution, or the
        # pre-defined bodylength for this ID in DEBUG mode.
        if DEBUG == False:
            self.bodylength = np.random.normal(data['bodylength_mean'],data['bodylength_deviation'])
        else:
            self.bodylength = debug['bodylengths'][self.id]
        
        #Assign width and height from relationship to length.
        self.bodywidth = (0.08571*self.bodylength) + (1.42885/1000)
        self.bodyheight = (0.21875*self.bodylength) - (1.875/1000)
        
        # Sensory ovoid: one sensory range along x and y, a quarter of it
        # along z.
        self.SQDx = self.bodylength * data['SensoryRange']
        self.SQDy = self.bodylength * data['SensoryRange']
        self.SQDz = self.bodylength * data['SensoryRange'] * 0.25
        
        #Swimming capabilities (Scruton et al 1998)
        # An earlier SWIMIT-based fit (5 deg temp) used to be assigned here
        # and was immediately overwritten by the values below, so it never
        # took effect. Kept for reference only (L = bodylength):
        #   swimBurst = -40.38*L**2 + 15.7082*L
        #   swimSust  = -10.95*L**2 + 10.1036*L
        # N.b. Prolonged swim speed not given in SWIMIT.
        self.swimBurst = 0.305 + 0.061*100*self.bodylength - 0.05742
        self.swimSust = 0.048*100*self.bodylength + 0.02*data['waterTemperature']
        
        # Wetted area calculation from Haefner and Bowen (2002).
        self.wettedArea = 0.465*self.bodylength**2.11
        
        # Predefine total energy expended
        self.totalEnergy0 = 0
        
        # Predefine total distance moved
        self.distMovedTotal = 0
        
        # Predefine total time taken
        self.totalTime = 0
        
        # Predefine axis / direction fish is pointing (directly upstream):
        # the unit vector opposite to the local flow velocity.
        self.heading = calcUnit(-interrogateVelocity3D(self.coordsCentroid[0], self.coordsCentroid[1], self.coordsCentroid[2]))
   
        # Define mass of individual based on bodylength
        self.mass = (10**(-4.867+2.96*np.log10(self.bodylength*1000)))/1000
        
        # Define Vn @ atm (FOR BUOYANCY)
        #self.VnAtm = self.mass*(data['rho_f'] - data['rho_w'])/(data['rho_f']*data['rho_w'])
        #self.VolConstant = self.VnAtm * data['AtmPressure']
        #self.Vs = self.VolConstant/(data['AtmPressure']+(data['rho_w']*9.81*(dom['maxZ']-self.z)))
        
        #Minimise or maximise the energy path taken? (True = minimise)
        # NOTE(review): initialised to -1 rather than a boolean - confirm the
        # meaning of -1 against the code that reads minEnergy.
        self.minEnergy = -1
        self.minEnergyIt = 0
        self.minEnergyTimeMax = 2.0 
        # minEnergyTimeMax (s) expressed as a number of fish timesteps.
        self.minEnergyItMax = self.minEnergyTimeMax/data['fishTimestep']
        self.maxEnergyThreshold = 0.05 #minimum velocity magnitude below which
        # individual will search for maximum energy path.
        self.minEnergyThreshold = 0.4 # minimum velocity magnitude below which
        # individual will stop searching for minimum energy path.
        self.gradThresholdMax = -1.25 # threshold of neg vel gradient causing 
        # individual to search for max energy.
        self.gradThresholdMin = -0.35 # threshold of neg vel gradient causing 
        # individual to stop searching for min energy.
        
        # Weights that define movement weighted averages
        #followFlow,minMaxEnergy,randomWalk,obAvoidance,colAvoidance, memory, tkeAvoidance
        self.weight_followFlow = 0.2
        self.weight_minMaxEnergy = 0.1
        self.weight_randWalk = 0.05 
        self.weight_obAvoid = 0.15 
        self.weight_colAvoid = 0.2 
        self.weight_memory = 0.2
        self.weight_tkeAvoid = 0.1
        
        self.movementWeights = ([self.weight_followFlow,
                                 self.weight_minMaxEnergy,
                                 self.weight_randWalk,
                                 self.weight_obAvoid,
                                 self.weight_colAvoid,
                                 self.weight_memory,
                                 self.weight_tkeAvoid])
        
        # Ratio to reduce movement vector if location otherwise unsuitable.
        self.reduceStep = 0.67
        
        # Boolean defining if fish is falling back.
        self.fallback = False
        self.fallbackIt = 0
        # 5 seconds expressed as a number of fish timesteps.
        self.maxFallbackIt = 5/data['fishTimestep'] 
        self.fallbackCount = 0
        self.escape = False
        # Reason why the individual has failed
        self.failReason = 'none'
        
        # tke
        self.tkeThreshold = 0.35
        
        # Memory stuff
        self.memoryTime = 10.0 # Time over which individual 'remembers' (s)
        self.numMemory = self.memoryTime/data['fishTimestep'] # Number of 
        # timesteps over which memory is stored
        # One row per remembered timestep; the first row holds the initial
        # heading, the rest start as zeros.
        self.dirMemory = np.zeros([int(round(self.numMemory,1)),3])
        self.dirMemory[0,:] = self.heading
        
    def findVertexLocations(self):
        # Builds the fish-aligned axes (X_prime along the heading, Y_prime
        # and Z_prime perpendicular to it) and from them the body nodes
        # (nose, tail, left, right, top, bottom), the matching points on the
        # sensory ovoid (suffix SO) and the points halfway between the two
        # (suffix SOMid).
        self.X_prime = calcUnit(self.heading)
        auxiliary = calcAuxiliaryVector(self.X_prime, 30)
        self.Y_prime = calcUnit(np.cross(self.X_prime, auxiliary))
        zAxis = calcUnit(np.cross(self.X_prime, self.Y_prime))
        # Z_prime is flipped when needed so that its z component is not negative.
        self.Z_prime = -zAxis if zAxis[2] < 0.0 else zAxis
        
        print(self.X_prime, self.Y_prime, self.Z_prime)
        # The checker returns True if the three axes are orthogonal.
        if not angleOrthogonalChecker(self.X_prime, self.Y_prime, self.Z_prime):
            print('--- Warning: Possible inaccurate conversion to fish axis')
        
        centre = self.coordsCentroid
        
        # Half body dimensions and full sensory reach along each fish axis.
        halfLength = self.X_prime*self.bodylength/2
        halfWidth = self.Y_prime*self.bodywidth/2
        halfHeight = self.Z_prime*self.bodyheight/2
        reachX = self.X_prime*self.SQDx
        reachY = self.Y_prime*self.SQDy
        reachZ = self.Z_prime*self.SQDz
        
        # Along the heading: nose ahead of the centroid, tail behind it.
        self.coordsNose = centre + halfLength
        self.coordsNoseSO = self.coordsNose + reachX
        self.coordsNoseSOMid = self.coordsNose + reachX*0.5
        
        self.coordsTail = centre - halfLength
        self.coordsTailSO = self.coordsTail - reachX
        self.coordsTailSOMid = self.coordsTail - reachX*0.5
        
        # Across the body: left is +Y_prime, right is -Y_prime.
        self.coordsLeft = centre + halfWidth
        self.coordsLeftSO = self.coordsLeft + reachY
        self.coordsLeftSOMid = self.coordsLeft + reachY*0.5
        
        self.coordsRight = centre - halfWidth
        self.coordsRightSO = self.coordsRight - reachY
        self.coordsRightSOMid = self.coordsRight - reachY*0.5
        
        # Vertical: top is +Z_prime, bottom is -Z_prime.
        self.coordsTop = centre + halfHeight
        self.coordsTopSO = self.coordsTop + reachZ
        self.coordsTopSOMid = self.coordsTop + reachZ*0.5
        
        self.coordsBottom = centre - halfHeight
        self.coordsBottomSO = self.coordsBottom - reachZ
        self.coordsBottomSOMid = self.coordsBottom - reachZ*0.5
        
    def findFrontLocations(self, ratio):
        # Front locations used for obstacle avoidance 
        self.coordsLeftFront = self.coordsLeftSOMid + self.X_prime*self.SQDx*ratio
        self.coordsRightFront = self.coordsRightSOMid + self.X_prime*self.SQDx*ratio
        self.coordsTopFront = self.coordsTopSOMid + self.X_prime*self.SQDx*ratio
        self.coordsBottomFront = self.coordsBottomSOMid + self.X_prime*self.SQDx*ratio
        self.coordsNoseSOFront = self.coordsNose + self.X_prime*self.SQDx*ratio

    def domainLimitsCentroid(self):
        # Function ensures individual stays within the domain limits: each
        # centroid coordinate is clamped in place to the [min, max] range
        # stored in `dom` for its axis.
        centroid = self.coordsCentroid
        for axis, label in enumerate(('X', 'Y', 'Z')):
            if centroid[axis] > dom['max' + label]:
                centroid[axis] = dom['max' + label]
        
            if centroid[axis] < dom['min' + label]:
                centroid[axis] = dom['min' + label]
        
            
    def domainLimits(self):
        # Shifts the whole fish (nose, tail, left, right, top, bottom and
        # centroid) back inside the domain. Along every axis on which a node
        # lies outside the limits in `dom`, all seven nodes are moved together
        # in 1 mm steps until every node is inside (limits included).
        #
        # A node lying exactly on a limit counts as inside. Previously such a
        # node was neither corrected nor accepted, so the loop never ended.
        # NOTE(review): if the fish spans more than the domain along an axis
        # (one node above the maximum and another below the minimum) the two
        # corrections cancel and the loop cannot finish - this was already
        # the case before.
        nodes = (self.coordsNose, self.coordsTail, self.coordsLeft,
                 self.coordsRight, self.coordsTop, self.coordsBottom,
                 self.coordsCentroid)
        limits = ((dom['minX'], dom['maxX']),
                  (dom['minY'], dom['maxY']),
                  (dom['minZ'], dom['maxZ']))
        step = 0.001
        
        while True:
            allInside = True
            for axis, (lower, upper) in enumerate(limits):
                # Extent of the fish along this axis before any correction.
                values = [node[axis] for node in nodes]
                highest = max(values)
                lowest = min(values)
                
                if highest > upper:
                    allInside = False
                    for node in nodes:
                        node[axis] -= step
                
                if lowest < lower:
                    allInside = False
                    for node in nodes:
                        node[axis] += step
            
            # Finished once a full pass needed no correction on any axis.
            if allInside:
                return

    def calcHeading(self):
        # Updates self.heading to the unit vector of a weighted sum of
        #   * the direction of the last movement   (weight alpha_cH = 0.5),
        #   * the first value returned by followFlow() (weight 1), and
        #   * the vector returned by memoryRule()  (weight alpha_mem = 1.5).
        # followFlow() and memoryRule() are defined elsewhere in the class.
        
        # Direction associated with the local velocity; the two remaining
        # return values are not needed here.
        localVel, _, _ = self.followFlow()
        
        # Direction of last movement.
        unitLastMove = calcUnit(calcVector(self.coordsCentroid, self.history[-1,:]))
        
        # No movement at all (every component zero): fall back to the flow
        # direction. The test is on the components themselves, not on their
        # sum, so a genuine move whose components cancel out
        # (e.g. [0.707, -0.707, 0]) is kept.
        if not np.any(unitLastMove):
            unitLastMove = localVel
        
        # contribution of memory
        alpha_mem = 1.5
        mem = self.memoryRule()
        
        # contribution of last movement direction
        alpha_cH = 0.5
        
        # Calculate new heading
        self.heading = calcUnit((alpha_cH*unitLastMove) + localVel + alpha_mem*mem)

    def findNodes(self):
        # Recomputes the six body nodes (nose, tail, left, right, top,
        # bottom) around the centroid from the current heading. The fish axes
        # are built locally and are not stored on the instance.
        xAxis = calcUnit(self.heading)
        yAxis = calcUnit(np.cross(xAxis, calcAuxiliaryVector(xAxis, 30)))
        zAxis = calcUnit(np.cross(xAxis, yAxis))
        
        # The checker returns True if the three axes are orthogonal.
        if not angleOrthogonalChecker(xAxis, yAxis, zAxis):
            print('--- Warning: Failed to convert to fish axis')
        
        centre = self.coordsCentroid
        halfLength = xAxis*self.bodylength/2
        halfWidth = yAxis*self.bodywidth/2
        halfHeight = zAxis*self.bodyheight/2
        
        self.coordsNose = centre + halfLength
        self.coordsTail = centre - halfLength
        self.coordsLeft = centre + halfWidth
        self.coordsRight = centre - halfWidth
        self.coordsTop = centre + halfHeight
        self.coordsBottom = centre - halfHeight

#%%
############################################################################### RULES

    def runRules(self,printer,i):
        # Evaluates every behavioural rule for this fish.
        #
        # printer : when true, a per-rule summary is printed.
        # i       : timestep identifier, used only in the printed summary.
        #
        # Returns (runResults, avgMagVel, avgVecVel). runResults is a list of
        # seven rule responses ordered as
        #   followFlow, minMaxEnergy, randomWalk, obAvoid, colAvoid, memory,
        #   tkeAvoid
        # which is the order in which calcWeightings fills
        # self.movementWeights. deterResponse multiplies entry k of the list
        # by weight k, so the two orders must agree.
        # NOTE(review): the list used to be ordered followFlow, minMaxEnergy,
        # obAvoid, colAvoid, tkeAvoid, randomWalk, memory, which paired five
        # of the rules with another rule's weight. Confirm no other caller
        # indexes runResults by the old positions.
        #
        # A rule whose module-level switch is off contributes a zero vector.
        # The rules are still evaluated in their original sequence because
        # some of them have side effects (obAvoidRule moves the front nodes,
        # randomWalk consumes random numbers).

        # Rule 1 returns unit vector (direction) as a result of the local
        # velocity field direction as well as the magnitude of local average
        # velocity
        flowDir, avgMagVel, avgVecVel = self.followFlow()
        if not followFlowSwitch:
            flowDir = np.array([0.0, 0.0, 0.0])

        if minMaxEnergySwitch:
            minMaxEnergy = self.minMaxEnergyFunc(avgMagVel)
        else:
            minMaxEnergy = np.array([0.0, 0.0, 0.0])

        # Rule 3 returns a unit vector as a response to an obstacle in front
        # of the individual.
        if obAvoidanceSwitch:
            obAvoid = self.obAvoidRule()
        else:
            obAvoid = np.array([0.0, 0.0, 0.0])

        # Rule 4 returns a unit vector as a response to other fish.
        if colAvoidanceSwitch:
            colAvoid = self.colAvoidRule()
        else:
            colAvoid = np.array([0.0, 0.0, 0.0])

        # Rule 5 returns a unit vector as a response to TKE minimisation.
        if tkeAvoidanceSwitch:
            tkeAvoid = self.tkeAvoid2()
        else:
            tkeAvoid = np.array([0.0, 0.0, 0.0])

        if randomWalkSwitch:
            walk = self.randomWalk()
        else:
            walk = np.array([0.0, 0.0, 0.0])

        if memorySwitch:
            mem = self.memoryRule()
        else:
            mem = np.array([0.0, 0.0, 0.0])

        # Assemble in the movementWeights order (see header comment).
        runResults = [flowDir, minMaxEnergy, walk, obAvoid, colAvoid, mem, tkeAvoid]

        if printer == True:
            print('\n \n - Timestep is            %s' % i)
            print(' - Fish ID is            %s' % self.id)
            print(' - energy search is      %s' % self.minEnergy)
            print(' - followFlow is         %s' % flowDir)
            print(' - minMaxEnergy is       %s' % minMaxEnergy)
            print(' - randomWalk is         %s' % walk)
            print(' - ObAvoidance is        %s' % obAvoid)
            print(' - ColAvoidance is       %s' % colAvoid)
            print(' - memory is             %s' % mem)
            print(' - tkeAvoidance is       %s \n' % tkeAvoid)

        return runResults, avgMagVel, avgVecVel
    
    def deterResponse(self, runResults):
        # Combines the individual rule responses into one movement direction.
        #
        # runResults : list of per-rule vectors; entry k is scaled IN PLACE
        #              by self.movementWeights[k], so the caller's arrays are
        #              modified.
        # Returns the final unit direction after the vertical clamp and the
        # wall check.

        # Refresh the weights for the current behavioural mode.
        self.calcWeightings()

        # Scale every component of every rule response by that rule's weight.
        for ruleIdx, ruleVec in enumerate(runResults):
            for comp in range(len(ruleVec)):
                ruleVec[comp] = ruleVec[comp]*self.movementWeights[ruleIdx]

        combined = calcUnit(sum(runResults))
        self.Response = combined #Debugging Store

        # Keep the final direction from being dominated by the z component.
        # This should only occur if the summation of all rules cancels out
        # the X and Y components.
        if np.abs(combined[2]) > 0.3:
            combined[2] = 0.0

        # Bias away from impassable zones, then renormalise.
        combined = calcUnit(self.wallCheck(combined))
        print(' - Final movement direction is %s' % combined)
        return combined
    
    def wallCheck(self, Response):
        # Checks the behavioural response for movement into walls.
        # Accepts an array (1x3) and returns an array (1x3).
        #
        # Each physical node is tested with checkLocation. For every node
        # that is not at an acceptable location, the unit vector built from
        # the centroid and that node is ADDED to the response, biasing it
        # away from the impassable zone. The input array is not modified.
        #
        # The original code wrote "Response =+ unit", which Python parses as
        # "Response = (+unit)": the behavioural response was thrown away and
        # only the last offending node counted.
        nodeChecks = (
            ('coordsNose', " - Warning: Detected impassable at nose"),
            ('coordsTail', " - Warning: Detected impassable at tail"),
            ('coordsLeft', " - Warning: Detected impassable at left"),
            ('coordsRight', " - Warning: Detected impassable at right"),
            ('coordsTop', " - Warning: Detected impassable at top"),
            ('coordsBottom', " - Warning: Detected impassable at bottom"),
        )

        for nodeName, warning in nodeChecks:
            node = getattr(self, nodeName)
            if not self.checkLocation(node):
                awayFromNode = calcUnit(calcVector(self.coordsCentroid, node))
                # Build a new array rather than adding in place, so arrays
                # that alias the incoming response are left alone.
                Response = Response + awayFromNode
                print(warning)

        return Response
    
    def calcWeightings(self):
        # Determines the movement weightings for the individual and stores
        # them, normalised, in self.movementWeights as a list ordered
        #   followFlow, minMaxEnergy, randWalk, obAvoid, colAvoid, memory,
        #   tkeAvoid
        # deterResponse multiplies entry k of its runResults by weight k.
        #
        # Normalisation divides by the sum of ABSOLUTE weights. Dividing by
        # the signed sum (as before) flipped the sign of every weight when
        # the negative fallback followFlow weight outweighed the rest, and
        # produced NaN when the weights summed to zero.

        # Energy-rule weight depends on the current search state.
        if self.minEnergy == -1:
            self.weight_minMaxEnergy = 0.1
        elif self.minEnergy == 0:
            self.weight_minMaxEnergy = 0.0
        else:
            self.weight_minMaxEnergy = 0.3 - 0.1*self.minEnergyIt*data['fishTimestep']

        if not self.fallback:
            # Normal behaviour: every rule uses its configured weight.
            weights = [self.weight_followFlow,
                       self.weight_minMaxEnergy,
                       self.weight_randWalk,
                       self.weight_obAvoid,
                       self.weight_colAvoid,
                       self.weight_memory,
                       self.weight_tkeAvoid]
        elif self.escape:
            # if escape is true, individual should move outside to the local
            # flow direction and local memory direction
            weights = [self.weight_followFlow,
                       0.0,
                       self.weight_randWalk,
                       0.0,
                       0.0,
                       self.weight_memory,
                       0.0]
        else:
            # if fallback is true, individual should follow the flow to move
            # downstream, hence the negated followFlow weight.
            weights = [-self.weight_followFlow,
                       0.0,
                       self.weight_randWalk,
                       0.0,
                       0.0,
                       0.0,
                       0.0]

        # Normalise so the absolute weights sum to one; signs are preserved.
        scale = sum(abs(w) for w in weights)
        if scale > 0.0:
            weights = [w/scale for w in weights]

        # Set weights
        self.movementWeights = weights
    
    def calcMoveMakeMove(self,avgMagVel,avgVecVal):
        # Chooses a swimming speed, builds self.moveVector and applies it to
        # the centroid if the resulting location is acceptable.
        #
        # avgMagVel : local velocity magnitude compared against the swim
        #             speeds (presumably the value returned by followFlow).
        # avgVecVal : local velocity vector added to the swimming velocity
        #             (presumably the value returned by followFlow).
        #
        # If the new location is not acceptable the move vector is shortened
        # repeatedly; if no acceptable location is found the individual does
        # not move and is switched to fallback behaviour.
        # May update self.escape, self.fallback, self.fallbackIt,
        # self.fallbackCount, self.moveVector and self.coordsCentroid (the
        # centroid is updated in place).

        print('\n - fallback is          %s' % self.fallback)
        print(' - escape is            %s' % self.escape)

        # Set speed to sustained unless local velocity is greater than
        # sustained velocity, where speed is set to burst.
        if self.escape == True:
            speed = self.swimBurst
            # Escape lasts for a single move. (This used to be the no-op
            # comparison "self.escape == False".)
            self.escape = False
        elif self.fallback == True:
            speed = 0.0
            if avgMagVel <= 0.1:
                speed = self.swimSust
        else:
            speed = self.swimSust
            if burstSwitch == True:
                # Test the burst limit first: with the sustained-speed test
                # in front, this branch could not be reached whenever
                # swimBurst >= swimSust.
                # If the local velocity is greater than burst, fallback.
                if avgMagVel >= self.swimBurst:
                    self.fallback = True
                    speed = 0.0
                elif avgMagVel >= self.swimSust:
                    speed = self.swimBurst

        self.moveVector = ((self.moveDirection*speed)+avgVecVal)*data['fishTimestep']

        # Count fallback steps; after maxFallbackIt of them fallback ends and
        # the completed fallback is tallied.
        if self.fallback == True:
            self.fallbackIt += 1
            if self.fallbackIt == self.maxFallbackIt:
                self.fallback = False
                self.fallbackIt = 0
                self.fallbackCount += 1

        #Store new location as tempLocation
        tempLocation = self.coordsCentroid + self.moveVector

        if self.checkLocation(tempLocation) == True:
            # If the new location is fine, move there.
            print(' - Final movement vector is %s' % self.moveVector)
            # Update coordinates
            self.coordsCentroid += self.moveVector
        else:
            # Otherwise, reduce the movement vector and redefine the temporary
            # coordinates then recheck if temp coords are appropriate.
            # Do this a number of times. If no suitable location is found,
            # the individual does not move.
            k = 1
            kmax = 8
            # Report this fish's own position (previously read through the
            # global loop index as fishes[j]).
            print('actual coords %s' % self.coordsCentroid)
            print('temp location %s' % tempLocation)
            while not self.checkLocation(tempLocation) and k < kmax:
                # While the tempLocation is unsuitable, scale the current
                # (already reduced) movement vector by reduceStep**k.
                self.moveVector = (self.reduceStep**k)*self.moveVector
                tempLocation = self.coordsCentroid + self.moveVector
                k += 1

            if self.checkLocation(tempLocation) == True:
                # Double check that location is acceptable, if so move.
                print(' - Final movement vector is %s' % self.moveVector)
                # Update coordinates
                self.coordsCentroid += self.moveVector

            else:
                # If after the reduction steps the location is not suitable,
                # the individual doesn't move
                self.fallback = True
                #self.escape = True
                print(' - No movement')
    
    def checkLocation(self,location):
        # Returns True when the location (indexable x, y, z) is inside the
        # domain and its passability is above data['passabilityThreshold'];
        # returns False otherwise.
        px, py, pz = location[0], location[1], location[2]

        # Outside the domain: passability is not queried at all.
        if not interrogateInDomain(px, py, pz):
            return False

        if interrogatePassability(px, py, pz) > data['passabilityThreshold']:
            return True
        return False
            
    def followFlow(self):
        # Assesses the 3D velocity at the six physical nodes and at the
        # twelve sensory ovoid nodes.
        #
        # Returns a 3-tuple (moveDir, velMagAvg, velVecAvg):
        #   moveDir   - negated unit vector of the averaged velocity
        #               directions, i.e. the direction the fish should move
        #               in; zero vector if no node is valid.
        #   velMagAvg - mean velocity magnitude over the six physical nodes
        #               (invalid nodes count as zero).
        #   velVecAvg - mean velocity vector over the valid nodes; zero
        #               vector if no node is valid.
        # A tuple is returned instead of np.array([...]) because the three
        # items have different shapes, and recent NumPy refuses to build an
        # array from such a ragged sequence. Callers that unpack or index
        # the result are unaffected.
        #
        # Nodes are read from self (previously via the global
        # fishes[self.id], which is presumably the same object).

        # Physical nodes: Nose, Tail, Left, Right, Top, Bottom
        physicalNames = ['coordsNose', 'coordsTail', 'coordsLeft', 'coordsRight', 'coordsTop', 'coordsBottom']
        # Sensory ovoid nodes
        ovoidNames = ['coordsNoseSO', 'coordsTailSO', 'coordsLeftSO', 'coordsRightSO', 'coordsTopSO', 'coordsBottomSO',
                      'coordsNoseSOMid', 'coordsTailSOMid', 'coordsLeftSOMid', 'coordsRightSOMid', 'coordsTopSOMid',
                      'coordsBottomSOMid']
        nodeNames = physicalNames + ovoidNames
        numPhysical = len(physicalNames)
        numNodes = len(nodeNames)

        velocities = np.zeros([numNodes,3])
        velocityUnits = np.zeros([numNodes,3])
        velocityMags = np.zeros(numNodes)
        velocityMagsScaled = np.zeros(numNodes)
        velocityBinaries = np.zeros(numNodes)

        for i, nodeName in enumerate(nodeNames):
            x,y,z = getattr(self, nodeName)
            # Only nodes inside the domain and above the passability
            # threshold contribute; all others keep their zero rows.
            if interrogateInDomain(x,y,z) and interrogatePassability(x,y,z)>data['passabilityThreshold']:
                # Interpolate velocity onto node
                velocities[i] = interrogateVelocity3D(x,y,z)
                velocityUnits[i] = calcUnit(velocities[i])
                # Magnitude of THIS node's velocity. The sensory ovoid nodes
                # used to read velocities[i] instead of velocities[6+i] and
                # so picked up another node's magnitude.
                velocityMags[i] = calcMagnitude(velocities[i])
                velocityMagsScaled[i] = np.exp(velocityMags[i])-1
                velocityBinaries[i] = 1.0

        # Scale the velocity directions with their associated magnitude to
        # ensure faster flows influence the direction more.
        if scaledVels:
            velocityUnits = velocityUnits * velocityMagsScaled[:, np.newaxis]

        validCount = np.sum(velocityBinaries)
        if validCount > 0:
            velUnitAvg = calcUnit(np.sum(velocityUnits, axis=0)/validCount)
            velVecAvg = np.sum(velocities, axis=0)/validCount
            moveDir = -velUnitAvg
        else:
            # No valid node: avoid dividing by zero and report no preference.
            velVecAvg = np.array([0.0, 0.0, 0.0])
            moveDir = np.array([0.0, 0.0, 0.0])

        velMagAvg = np.sum(velocityMags[:numPhysical])/numPhysical

        return moveDir, velMagAvg, velVecAvg
     
    def minMaxEnergyFunc(self,avgMagVel):
        # Searches for either the minimum or the maximum velocity ahead of
        # the individual and returns the unit direction towards it.
        #
        # self.minEnergy holds the search state:
        #    1 : search for the maximum velocity
        #    0 : not searching (zero vector returned)
        #   -1 : search for the minimum velocity
        #
        # While the state is 1 it is only left once self.minEnergyIt has
        # reached self.minEnergyItMax. Otherwise the state is re-evaluated
        # from (a) whether obAvoidRule reports an obstacle, (b) avgMagVel
        # against maxEnergyThreshold / minEnergyThreshold and (c) the result
        # of calcVelGrads.
        #
        # avgMagVel : local velocity magnitude compared with the thresholds.
        # Returns a unit vector, or a zero vector when not searching or when
        # none of the forward nodes is valid.

        if self.minEnergy == 1:
            # Already searching for max energy: count up until the maximum
            # is reached, then revert to the minimum search.
            if self.minEnergyIt < self.minEnergyItMax:
                self.minEnergyIt += 1
            else:
                self.minEnergy = -1
                self.minEnergyIt = 0

        else:
            # Determine if an obstacle is found. Any non-zero component
            # counts; summing the components would miss responses whose
            # components cancel, e.g. (0.7, -0.7, 0).
            obstacle = bool(np.any(self.obAvoidRule()))

            #Determine if low magnitude found.
            if avgMagVel <= self.maxEnergyThreshold:
                lowMag = 1
            elif avgMagVel <= self.minEnergyThreshold:
                lowMag = 0
            else:
                lowMag = -1

            # Check for large negative velocity gradients. calcVelGrads is
            # evaluated once; any result other than 1 or 0 maps to -1.
            velGrad = self.calcVelGrads()
            if velGrad == 1:
                grad = 1
            elif velGrad == 0:
                grad = 0
            else:
                grad = -1

            # Check for an obstacle or a lowMag or a large negative gradient.
            # If it exists, search for maximum energy path
            if obstacle or lowMag == 1 or grad == 1:
                print('\n - Obstacle is     %s' % obstacle)
                print(' - lowMag is         %s' % lowMag)
                print(' - grad is           %s' % grad)
                # set minEnergy = 1 (max) and reset count.
                self.minEnergy = 1
                self.minEnergyIt = 0

            #Otherwise, see if the lowMag or Grad are suitable to stop search
            elif lowMag == 0 or grad == 0:
                self.minEnergy = 0

            # Else, search for minimum energy
            else:
                self.minEnergy = -1

        noPreference = np.array([0.0,0.0,0.0])

        # State 0 (or any unexpected state): the individual is not searching
        # based on energy, i.e. it is happy where it is.
        if self.minEnergy != -1 and self.minEnergy != 1:
            return noPreference

        # Now determine the velocity magnitudes at the forward nodes.
        Names = ['coordsNoseSOFront', 'coordsLeftFront', 'coordsRightFront']
        velocityMagsFront = np.zeros(len(Names))
        nodeValid = np.zeros(len(Names), dtype=bool)
        for i, nodeName in enumerate(Names):
            x,y,z = getattr(self, nodeName)
            if interrogateInDomain(x,y,z) and interrogatePassability(x,y,z)>data['passabilityThreshold']:
                velocityMagsFront[i] = calcMagnitude(interrogateVelocity3D(x,y,z))
                nodeValid[i] = True

        # With no valid node there is nothing to steer towards. (The old
        # sentinel values made the first, invalid, node the target.)
        if not nodeValid.any():
            return noPreference

        if self.minEnergy == -1:
            # Pick the smallest; invalid nodes can never win.
            idx = np.argmin(np.where(nodeValid, velocityMagsFront, np.inf))
        else:
            # Pick the largest; invalid nodes can never win.
            idx = np.argmax(np.where(nodeValid, velocityMagsFront, -np.inf))

        # Unit vector built from the chosen node and the centroid.
        x,y,z = getattr(self, Names[idx])
        velUnitFront = calcUnit(calcVector([x,y,z],self.coordsCentroid))

        return velUnitFront
    
    def calcVelGrads(self):
        # Determines the gradient of the main velocity component across the
        # nodes on the left and on the right of the body, and classifies the
        # most negative of the four gradients:
        #    1 : below self.gradThresholdMax (negative and large)
        #    0 : below self.gradThresholdMin (negative and moderate)
        #   -1 : acceptable
        #
        # Each side has three nodes (body, mid sensory ovoid, sensory ovoid)
        # giving two gradients per side, each over a spacing of self.SQDx/2.
        # A gradient is set to zero when either of its two samples is zero,
        # which includes samples taken at invalid nodes.
        #
        # The second right-hand gradient used to be written into
        # gradRight[0], overwriting the first and leaving gradRight[1] at
        # zero; all four gradients are now kept.
        sideNames = (
            ['coordsLeft', 'coordsLeftSOMid', 'coordsLeftSO'],
            ['coordsRight', 'coordsRightSOMid', 'coordsRightSO'],
        )

        sideVelocities = []
        for names in sideNames:
            sampled = np.zeros([len(names),3])
            for i, nodeName in enumerate(names):
                x,y,z = getattr(self, nodeName)
                if interrogateInDomain(x,y,z) and interrogatePassability(x,y,z)>data['passabilityThreshold']:
                    sampled[i] = interrogateVelocity3D(x,y,z)
                # If the point is invalid its velocities stay at zero.
            sideVelocities.append(sampled)

        # Determine which the main velocity component is (at centre).
        # NOTE(review): argmax works on the signed components, so a dominant
        # negative component is not selected - confirm this is intended.
        idx = np.argmax(interrogateVelocity3D(self.coordsCentroid[0],self.coordsCentroid[1],self.coordsCentroid[2]))

        spacing = self.SQDx/2
        grads = []
        for sampled in sideVelocities:
            # Throw away other data
            component = sampled[:, idx]
            for inner, outer in zip(component[:-1], component[1:]):
                # Check for zero magnitudes
                if inner == 0.0 or outer == 0.0:
                    grads.append(0.0)
                else:
                    grads.append((outer-inner)/spacing)

        steepest = min(grads)
        if steepest < self.gradThresholdMax:
            # Gradient is negative and large. Therefore, search for larger velocities
            largeGrad = 1
        elif steepest < self.gradThresholdMin:
            # Gradient is negative and moderate. Therefore, stop
            # searching for minimum energy.
            largeGrad = 0
        else:
            # Gradient is acceptable, so do nothing
            largeGrad = -1

        return largeGrad
            
#    def buoyRule(self):
#        # Calculate the buoyancy force based on an acclimatised depth.
#        move2 = np.array([0.0, 0.0, 0.0])
#        # Vn is volume required for neutral buoyancy
#        Vn = self.VolConstant/(data['AtmPressure']+(data['rho_w']*9.81*(dom['maxZ']-self.z)))
#        # B is buoyancy force, based on required volume versus actual volume.
#        B = (self.Vs - Vn)*data['rho_w']*9.81
#        # a is acceleration
#        a = B/self.mass
#        # Resulting movement vector (NOT UNIT) just in vertical direction
#        move2[2] = a*data['fishTimestep']*data['fishTimestep']
#        return move2
#    
#    # Rule 2 returns the VECTOR (NOT UNIT), which is a force result due 
#        # to buoyancy.
#        buoyancy = self.buoyRule()
#        runResults.append(buoyancy)
    
    def obAvoidRule(self):
        # Obstacle avoidance based on five forward nodes (nose, left, right,
        # top, bottom).
        #
        # The nodes are tested for passability. While all of them are
        # blocked, findFrontLocations is called with a decreasing factor to
        # redefine the front locations slightly closer to the fish body and
        # the test is repeated, at most jmax-1 times.
        #
        # Returns the unit vector built from the passability-weighted mean
        # of the tested nodes and the centroid, or a zero vector when every
        # node is clear or no passable node was found.
        #
        # Side effect: findFrontLocations is called once per test, exactly
        # as before, so the front node attributes are left redefined.
        #
        # Two corrections compared with the previous version:
        #  * the passabilities are combined with the coordinates of the
        #    nodes that were actually tested, not with the node set produced
        #    by the following findFrontLocations call;
        #  * a success on the last permitted attempt is no longer discarded.
        frontNames = ('coordsNoseSOFront', 'coordsLeftFront', 'coordsRightFront',
                      'coordsTopFront', 'coordsBottomFront')
        numNodes = len(frontNames)
        jmax = 20
        j = 1
        found = False
        testedNodes = []
        frontNodepassabilities = np.zeros(numNodes)

        while not found and j < jmax:
            # Snapshot the nodes being tested (copied, in case
            # findFrontLocations updates the arrays in place).
            testedNodes = [np.array(getattr(self, nodeName), dtype=float) for nodeName in frontNames]

            # for each point, check they are in the domain. if they are find
            # their passability. If they aren't, set passability to zero.
            for i, node in enumerate(testedNodes):
                if interrogateInDomain(node[0], node[1], node[2]) == True:
                    passability = interrogatePassability(node[0], node[1], node[2])
                else:
                    passability = 0.0
                frontNodepassabilities[i] = polarise(passability, 0.1)

            found = np.sum(frontNodepassabilities) != 0

            # Redefine front locations slightly closer to the fish body
            # (read through self; previously via the global fishes[self.id]).
            self.findFrontLocations(0.8*(0.67/(j)))
            j += 1

        passableTotal = np.sum(frontNodepassabilities)

        # If no passable node was found, or all nodes are clear, the fish
        # has no preference.
        if not found or passableTotal == numNodes:
            return np.array([0.0,0.0,0.0])

        # Multiply the coordinates of each point by its passability, which
        # effectively removes the unpassable points, then divide by the
        # summed passability to get the average coordinate of the passable
        # points.
        weightedSum = np.zeros(3)
        for node, passability in zip(testedNodes, frontNodepassabilities):
            weightedSum = weightedSum + node*passability
        target = weightedSum/passableTotal

        # This defines the location the individual wants to move towards
        move2coords = [target[0], target[1], target[2]]

        # Returns the direction the fish wants to move in.
        return calcUnit(calcVector(move2coords,self.coordsCentroid))
    
    def colAvoidRule(self):
        # Collision avoidance: returns the unit vector of the averaged
        # directions built from this fish's centroid and the centroid of
        # every influencing fish, or a zero vector when nothing influences.
        #
        # A fish influences when its id is in neither `passed` nor `failed`
        # and its distance is more than zero (so the individual does not
        # stimulate itself) and less than data['repulsionDist'].
        repelSum = np.zeros(3)
        numInfluencing = 0

        for j in range(0, data['fishNum']):
            other = fishes[j]
            separation = calcDistance(self.coordsCentroid, other.coordsCentroid)

            # Skip fish that have left the simulation.
            if other.id in passed or other.id in failed:
                continue
            # Skip self (zero distance) and fish beyond the repulsion range.
            if not (separation > 0.0 and separation < data['repulsionDist']):
                continue

            vect = calcVector(self.coordsCentroid, other.coordsCentroid)
            print(vect)
            repelSum = repelSum + calcUnit(vect)
            numInfluencing += 1

        # Average over the influencing fish and renormalise to a unit vector.
        if numInfluencing > 0:
            return calcUnit(repelSum/numInfluencing)
        return np.array([0., 0., 0.])
       
    def tkeAvoid2(self):
        # If the individual senses TKE above self.tkeThreshold at any of the
        # three forward nodes, it steers towards the forward node with the
        # smallest absolute TKE. Otherwise a zero vector is returned.
        #
        # Invalid nodes (outside the domain or below the passability
        # threshold) are marked with -999, so they never exceed the
        # threshold and, having the largest absolute value, are not selected
        # while a valid node exists.
        #
        # The threshold test now runs once, after all nodes have been
        # sampled. It used to sit inside the sampling loop, where it was
        # evaluated against a partially filled array (unsampled entries
        # reading as zero TKE) and printed intermediate picks.
        Names = ['coordsNoseSOFront', 'coordsLeftFront', 'coordsRightFront']
        TKEFront = np.full(len(Names), -999.0)

        for i, nodeName in enumerate(Names):
            x,y,z = getattr(self, nodeName)
            if interrogateInDomain(x,y,z) and interrogatePassability(x,y,z)>data['passabilityThreshold']:
                TKEFront[i] = interrogateTKE(x,y,z)

        # Default: no preference.
        unit = np.array([0.0, 0.0, 0.0])

        if np.max(TKEFront) > self.tkeThreshold:
            # Pick the smallest absolute
            idx = np.argmin(np.abs(TKEFront))
            print(idx)
            print(Names[idx])
            # Unit vector built from the chosen node and the centroid.
            x,y,z = getattr(self, Names[idx])
            unit = calcUnit(calcVector([x,y,z],self.coordsCentroid))

        return unit
    
    def randomWalk(self):
        # Returns a unit vector describing a random walk direction.
        # Points are drawn on a sphere of radius 1 around the centroid until
        # one lies within sin(angleLimit) of the centroid's z coordinate;
        # the result is calcUnit of the vector built from the centroid and
        # that point.
        radius = 1
        angleLimit = 0.175 #10 degrees
        maxRise = np.sin(angleLimit)
        centre = self.coordsCentroid

        # Start from a far-away z so the acceptance test fails and points
        # are drawn.
        px, py, pz = 0.0, 0.0, -9999990.0

        while np.abs(pz-centre[2]) > maxRise:
            theta = random.uniform(0,2*np.pi)
            phi = random.uniform(0,2*np.pi)

            horizontal = radius*np.sin(theta)
            px = centre[0] + horizontal*np.cos(phi)
            py = centre[1] + horizontal*np.sin(phi)
            pz = centre[2] + radius*np.cos(theta)

        return calcUnit(calcVector(centre, np.array([px,py,pz])))
    
    def memoryRule(self):
        # Returns calcUnit of the column-wise average of the first three
        # columns of the memory array self.dirMemory.
        columnMeans = np.zeros([3])
        for axis in range(3):
            columnMeans[axis] = np.average(self.dirMemory[:,axis])

        return calcUnit(columnMeans)
    
    def updateMemory(self,timestep,avgVelVec):
        # Stores the negated unit vector of avgVelVec in the individual's
        # memory array. The row written is timestep modulo the memory
        # length, where the length is self.numMemory rounded to one decimal
        # and then truncated to an int.
        reversedFlow = -calcUnit(avgVelVec)
        numSlots = int(round(self.numMemory,1))
        self.dirMemory[timestep % numSlots,:] = reversedFlow
    
    def energyExpended0(self, distMoved, Mag):
        # Energy expended in one timestep, following the Khan (2006) method.
        #   distMoved : distance covered by the fish in the timestep
        #   Mag       : magnitude added to the sustained swim speed to give
        #               the speed used in the calculation
        coefficient = 0.0444
        speed = self.swimSust + Mag
        reynolds = calcReynoldsNumber(speed,self.bodylength)
        
        return coefficient * data['rho_w'] * self.wettedArea * speed**2 * distMoved * reynolds**(-0.2)
    
    def failCheck(self):
        # Reports whether the fish should be treated as failed and records
        # why in self.failReason. Three criteria are tested in turn; when
        # several apply, the reason of the last one tested is kept.
        
        hasFailed = False
        
        # 1. No movement over the five most recent recorded distances.
        if len(self.history) > 5:
            lastFive = [self.distMovedHistory[-back] for back in range(1, 6)]
            if lastFive[0] + lastFive[1] + lastFive[2] + lastFive[3] + lastFive[4] == 0.0:
                hasFailed = True
                self.failReason = 'no movement'
        
        # 2. Passability at the centroid is below the threshold.
        cx = self.coordsCentroid[0]
        cy = self.coordsCentroid[1]
        cz = self.coordsCentroid[2]
        if interrogatePassability(cx, cy, cz) < data['passabilityThreshold']:
            hasFailed = True
            self.failReason = 'low passability'
        
        # 3. Too many fallbacks.
        if self.fallbackCount > data['fallbackMax']:
            hasFailed = True
            self.failReason = 'number of fallbacks'
        
        return hasFailed
    
    def fail(self):
        # Adds this individual to the module-level failed list and prints
        # the reason together with the running passed/failed/remaining tally.
        # The fish's own id is recorded, so the method does not depend on
        # any loop index defined by the caller.
        
        failed.append(self.id)
        print('\n - Fish ID %s has failed due to' % self.id, self.failReason)
        print('- %s fish failed to pass.' % len(failed))
        print(' - %s fish passed successfully.' % len(passed))
        print(' - %s left in the domain. \n' % (data['fishNum'] - (len(passed)+len(failed))))
    
    def passCheck(self):
        # Checks whether the fish has passed through the domain.
        # The fish counts as passed when either its centroid or its nose lies
        # strictly inside the target zone on all three axes.
        
        # (min key, max key) of the target zone for the x, y and z axes.
        zoneKeys = (('targetZoneXmin', 'targetZoneXmax'),
                    ('targetZoneYmin', 'targetZoneYmax'),
                    ('targetZoneZmin', 'targetZoneZmax'))
        
        def insideTargetZone(point):
            # Every axis is evaluated before the results are combined.
            perAxis = [point[axis] < data[upper] and point[axis] > data[lower]
                       for axis, (lower, upper) in enumerate(zoneKeys)]
            return all(perAxis)
        
        centrePass = insideTargetZone(self.coordsCentroid)
        nosePass = insideTargetZone(self.coordsNose)
        
        return centrePass or nosePass
    
    def pass2(self):
        # Adds this individual to the module-level passed list and prints the
        # running passed/failed/remaining tally. The fish's own id is
        # recorded, so the method does not depend on any loop index defined
        # by the caller.
        passed.append(self.id)
        print('\n - Fish ID %s has passed' % self.id)
        print(' - %s fish failed to pass.' % len(failed))
        print(' - %s fish passed successfully.' % len(passed))
        print(' - %s left in the domain. \n' % (data['fishNum'] - (len(passed)+len(failed))))
                            
#%%        
############################################################################### MAKE FISH
# Creates a list called "fishes" in which the fish are built.
fishes = []
# Create fish
print(" ------------------------------------------------------- ")
print("                 ----- Creating Fish -----               \n")

# ID is both the id given to the next fish and its index in the fishes list;
# it only advances when a fish is actually created.
ID = 0


for i in range(data['fishNum']): 
    if DEBUG == True:
        # Debug mode: use the preset creation location, no acceptance test.
        createLoc = debug['creations'][ID]
        checkLoc = True
    else:
        # Determine an appropriate creation location and check whether it
        # is truly acceptable.
        createLoc = determineCreationLocation(ID,data['bodylength_mean'])
        checkLoc = checkCreationLocation(ID,data['bodylength_mean'],createLoc[0],createLoc[1],createLoc[2])

    if checkLoc == True:
        # The fish joins the list before its nodes are located and checked.
        newFish = fish(ID, createLoc[0], createLoc[1], createLoc[2])
        fishes.append(newFish)

        # Find the node locations.
        newFish.findVertexLocations() #QUERY
        newFish.findFrontLocations(0.8) #QUERY

        # Check new nodes are acceptable.
        newFish.domainLimits() #QUERY

        print(' --- Fish successfully created')
        ID += 1
    else:
        # Location rejected: no fish is created this iteration.
        print(' --- Fish not created')

print("\n ------------------------------------------------------- ")

if len(fishes) == data['fishNum']:
    print("               ----- All Fish Created -----            ")    
else:
    # Fewer fish than requested: continue with the number actually created.
    data['fishNum'] = len(fishes)
    print("            ----- Fish creation limited! -----           ")
    print("               ----- %s fish created. -----              " % data['fishNum'])
    print("   ----- Can't find suitable creation locations -----    ")
    print("  --- If more fish required, define larger creation      ")
    print("      zone or turn off collisions.                       ")

print(" ------------------------------------------------------- \n")

#%%
############################################################################### DEFINE PASSED/FAILED FISH
# Once a fish has made it upstream, its id is put into the "passed" list.
passed = []

# If for any reason a fish is deemed unable to pass, its id is added to the
# "failed" list.
failed = []

############################################################################### DEFINE INITIATIVE
# The initiative list controls the order that the fish move in. It starts in
# creation order.
initiative = range(data['fishNum'])

#%%
############################################################################### MAIN LOOP
# Initialise iteration count
# Initialise iteration count
i = 0
timesteps = i

# Step until the iteration limit is reached or every fish has either passed
# or failed.
while i < data['Tmax'] and (len(passed)+len(failed)) != data['fishNum']:

    # For all the fish in the initiative list.
    # This controls the order of movement each turn.
    for j in initiative:

        # Fish already known to have passed or failed take no action; only
        # the remaining fish are checked and moved.
        if fishes[j].id not in passed and fishes[j].id not in failed:

            # If failed, add the fish to the failed list and warn the user.
            if fishes[j].failCheck():
                fishes[j].fail()

            # Otherwise commence movement loop.
            else:
                #Run rules (including Print rule outputs)
                runResults, avgMagVel, avgVecVal = fishes[j].runRules(printer=True, i=i)
                #Determine movement direction by picking response.
                fishes[j].moveDirection = fishes[j].deterResponse(runResults)
                #Determine actual movement of the individual and make the
                #move.
                fishes[j].calcMoveMakeMove(avgMagVel,avgVecVal)

                # Check whether the fish has now passed based on new coordinates.
                if fishes[j].passCheck():
                    fishes[j].pass2()

                # Check for NaNs and warn user. np.isnan returns NumPy
                # booleans, so the result is tested by truth value; an
                # identity comparison against True would never match.
                # NOTE(review): the message says EXITING but the run carries
                # on afterwards - confirm whether it should stop here.
                if np.any(np.isnan(fishes[j].coordsCentroid)):
                    print('\n--- Warning: NaN found. EXITING.')

                # Check the new centroid position is acceptable.
                fishes[j].domainLimitsCentroid()

                # Calculate the new heading of the fish
                fishes[j].calcHeading()
                print(' - New heading is    %s' % fishes[j].heading)

                # Calculate the location of the other nodes based on the new
                # centroid position and the new heading.
                fishes[j].findVertexLocations()
                fishes[j].findFrontLocations(0.8)
                #fishes[j].findNodes()
                #fishes[j].vertexLocations()

                # Check that the new node locations are acceptable.
                fishes[j].domainLimits()

                # Calculate the energy expended and add it to the total tally.
                # NOTE(review): distMoved[j] is only refreshed further down,
                # so unless calcMoveMakeMove updates it this uses the
                # previous timestep's distance - confirm this is intended.
                fishes[j].totalEnergy0 += fishes[j].energyExpended0(distMoved[j],avgMagVel)

                # Update the memory array with the spatially averaged
                # velocity direction this timestep.
                fishes[j].updateMemory(i,avgVecVal)

                # Count the time taken.
                fishes[j].totalTime += data['fishTimestep']

                # Update history
                fishes[j].history = np.vstack([fishes[j].history, fishes[j].coordsCentroid])

        # Calculate distance moved in timestep and update distMovedHistory.
        # Membership is re-tested because the fish may have passed or failed
        # during this very step, in which case no distance is recorded.
        if fishes[j].id in passed or fishes[j].id in failed:
            fishes[j].distMoved = 0.0
        else:
            fishes[j].distMoved = calcDistance(fishes[j].coordsCentroid, fishes[j].history[-2])
            fishes[j].distMovedHistory = np.vstack([fishes[j].distMovedHistory, fishes[j].distMoved])

        # List of most recent distances moved. Used for the initiative function.
        distMoved[j] = fishes[j].distMoved

        # Summation of distance moved to date.
        fishes[j].distMovedTotal += fishes[j].distMoved

    # Calculate new initiative list
    initiative = initiativeFunction(distMoved, data['fishNum'])

    # Increase step count
    i += 1
    timesteps +=1

    # On reaching the step limit, every fish still in the domain is failed.
    # The passed/failed lists hold fish ids, so the id is what is looked up.
    if i == data['Tmax']:
        print('\n' + ' --- Warning: Timed Out')
        for j in range(data['fishNum']):
            if fishes[j].id not in passed and fishes[j].id not in failed:
                fishes[j].failReason = 'timed out'
                fishes[j].fail()


#%%
############################################################################### CALCULATE METRICS (OUTSIDE LOOP)
 # Predefine average distance moved
# Average total distance moved per fish, rounded with no decimal places.
# When no fish could be created data['fishNum'] is zero; the average is then
# left at 0 instead of dividing by zero.
avgDistMoved = 0
if data['fishNum'] > 0:
    avgDistMoved = sum(fishes[k].distMovedTotal for k in range(data['fishNum']))
    avgDistMoved = round(avgDistMoved/data['fishNum'])

# Simulated time is the number of steps taken times the step length.
timeSimulated = round(timesteps*data['fishTimestep'], 2)      


#%%
############################################################################### PRINTS
# Run summary. The run is reported as successful when every fish has been
# recorded as either passed or failed.
allFishResolved = (len(passed)+len(failed)) == data['fishNum']

print("\n \n" + " ------------------------------------------------------- ")
print("           ----- Finished successfully. -----            " if allFishResolved
      else "          ----- Warning: Finished early. -----           ")
print("              ----- Writing to file. -----               ")
print("              ----- %s Fish Simulated -----              " % data['fishNum'])
print("               ----- %s Fish Passed -----                " % len(passed))
print("               ----- %s Fish Failed -----                " % len(failed))
# Wall-clock seconds since start_time, rounded to a whole number.
elapsedSeconds = round(time.perf_counter() - start_time)
print("     ----- The code took %s seconds to execute -----     " % (elapsedSeconds))
print(" ------------------------------------------------------- " + "\n")
print(" ------------------------------------------------------- ")
print("           ---- %   s seconds simulated ----            " % timeSimulated)
print("        ---- Average %s metres travelled ----         " % avgDistMoved)
print(" ------------------------------------------------------- " + "\n")        


############################################################################### WRITE OUTPUT FILE
if WRITE == True:
    # Passage efficiency as a percentage of the simulated fish. With zero
    # fish simulated it is reported as 0.0 instead of dividing by zero.
    if data['fishNum'] > 0:
        passageEfficiency = (len(passed)/data['fishNum'])*100
    else:
        passageEfficiency = 0.0

    # Bulk metrics for the whole run. Every file below is opened in a with
    # block so that it is closed even if a write fails.
    with open("outputs/passageMetrics.txt","w") as metricsFile:
        metricsFile.write('This file details the bulk passage metrics from the creationZone to the targetZone. \n')
        metricsFile.write('total fish requested: ' + repr(data['fishNumReq']) + '\n')
        metricsFile.write('total fish simulated: ' + repr(data['fishNum']) + '\n')
        metricsFile.write('total successful passages: ' + repr(len(passed)) + '\n')
        metricsFile.write('total failed passages: ' + repr(len(failed)) + '\n')
        metricsFile.write('predicted passage efficiency of domain: ' + repr(passageEfficiency) + '\n')

    # One text report per fish: its properties, the model parameters, the
    # rule switches and the outcome of the run.
    for k in range(data['fishNum']):
        with open("outputs/fishPathOutput" + repr(k) +".txt","w") as fishFile:
            fishFile.write('fish ' + repr(k) + '\n')
            fishFile.write('fish variables \n')
            fishFile.write('creation location: ' + repr(fishes[k].history[0]) + ' (m) \n')
            fishFile.write('bodylength: ' + repr(np.round(fishes[k].bodylength,5)) + ' m \n')
            fishFile.write('burst speed: ' + repr(np.round(fishes[k].swimBurst,5)) + ' m/s \n')
            fishFile.write('sustained speed: ' + repr(np.round(fishes[k].swimSust,5)) + ' m/s \n')

            fishFile.write('\n' + 'model parameters \n')
            fishFile.write('timestep: ' + repr(data['fishTimestep']) + ' (s) \n')
            fishFile.write('Tmax: ' + repr(data['Tmax']) + '\n')
            fishFile.write('sensoryRange: ' + repr(data['SensoryRange']) + ' (bodylengths) \n')
            fishFile.write('repulsionDist: ' + repr(data['repulsionDist']) + ' (m) \n')
            fishFile.write('fallbackMax: ' + repr(data['fallbackMax']) + '\n')

            fishFile.write('\n' + 'rule settings \n')
            fishFile.write('debug is ' + repr(DEBUG)+ '\n')
            fishFile.write('followFlow is '+ repr(followFlowSwitch)+ '\n')
            fishFile.write('minMaxEnergy is '+ repr(minMaxEnergySwitch)+ '\n')
            fishFile.write('randomwalk is '+ repr(randomWalkSwitch)+ '\n')
            fishFile.write('obAvoidance is '+ repr(obAvoidanceSwitch)+ '\n')
            fishFile.write('colAvoidance is '+ repr(colAvoidanceSwitch)+ '\n')
            fishFile.write('memory is '+ repr(memorySwitch)+ '\n')
            fishFile.write('tkeAvoidance is '+ repr(tkeAvoidanceSwitch)+ '\n')
            fishFile.write('scaledVels is '+ repr(scaledVels)+ '\n')
            fishFile.write('burstVels is '+ repr(burstSwitch)+ '\n \n')

            if fishes[k].id in passed:
                fishFile.write('successful passage \n')
            else:
                fishFile.write('failed passage. \n')
                fishFile.write('failed due to ' + repr(fishes[k].failReason) + '\n')
            fishFile.write('energy expended: ' + repr(np.round(fishes[k].totalEnergy0,5)) + ' J'+ '\n')
            fishFile.write('total distance travelled: ' + repr(np.round(fishes[k].distMovedTotal,5)) + ' m'+ '\n')
            fishFile.write('total time taken: ' + repr(np.round(fishes[k].totalTime,5)) + ' s'+'\n')
            fishFile.write('final location: ' + repr(fishes[k].history[-1]) + ' (m) \n')
            fishFile.write('number of fallbacks: ' +repr(fishes[k].fallbackCount)+'\n')

    # One CSV per fish holding the recorded centroid path.
    # NOTE(review): a CSVWRITE flag is defined at the top of the file, but
    # this export is controlled by WRITE only - confirm which is intended.
    for k in range(data['fishNum']): 
        with open("outputs/fishPathOutput" +repr(k) + ".csv","w", newline = '') as fish_path:
            fish_path_writer = csv.writer(fish_path,delimiter=',',quoting=csv.QUOTE_NONE)
            fish_path_writer.writerow(['X', 'Y', 'Z'])
            for x in fishes[k].history:
                fish_path_writer.writerow([x[0],x[1],x[2]])
    

#%%
############################################################################### PLOTTING
if PLOT == True:

    print(" --- Plotting" + "...")

    # Figure 1: 3D view of the fish paths over the DEM surface.
    fig = plt.figure()
    # add_subplot is used to create the 3D axes; passing projection to
    # fig.gca() is not accepted by current Matplotlib releases.
    ax = fig.add_subplot(111, projection='3d')
    ax.set_xlim([0, dom['maxX']])
    ax.set_ylim([0, dom['maxY']])
    ax.set_zlim([0, dom['maxZ']])
    ax.plot_surface(envData['X3D'][:,:,0], envData['Y3D'][:,:,0], envData['DEM'], rstride=10, cstride=10, color='g', alpha=0.4)
    ax.view_init(elev=50, azim=-150)
    
    ax.text(-2,-2,0,s='Simulated time taken: %s seconds' % (timesteps*data['fishTimestep']))
    if DEBUG is True:
        # Debug colours are cycled so that more fish than colours still plot.
        debugColours = debug['plotColours']
        for j in range(data['fishNum']): 
            ax.plot(fishes[j].history[:,0], fishes[j].history[:,1], fishes[j].history[:,2], label='parametric curve', color=debugColours[j % len(debugColours)])
    else:
        for j in range(data['fishNum']): 
            ax.plot(fishes[j].history[:,0], fishes[j].history[:,1], fishes[j].history[:,2], label='parametric curve', color=np.random.rand(3,))
    
    # Figure 2: plan (x-y) view of the same paths.
    fig2 = plt.figure()
    plt.axis('scaled')
    plt.xlim([0, dom['maxX']])
    plt.ylim([0, dom['maxY']])
    for j in range(data['fishNum']):
        plt.plot(fishes[j].history[:,0], fishes[j].history[:,1], color=np.random.rand(3,))

    # Show once, after both figures exist, so that the plan view is also
    # displayed when the file is run as a plain script.
    plt.show()

#%%
############################################################################### END
# Closing banner, printed once the script has run through to the end.
print("\n" + " ------------------------------------------------------- ")
print("     ------- The code executed successfully. -------     ")
print(" ------------------------------------------------------- " + "\n")



