# MIT OpenCourseWare: https://ocw.mit.edu
# 12.810 Dynamics of the Atmosphere (Spring 2023)
# License: Creative Commons BY-NC-SA
# For information about citing these materials or our Terms of Use, visit: https://ocw.mit.edu/terms.

# Mountain wave forced by a ridge

# A fast Fourier transform (FFT) is used to determine the spectral coefficients of the solution.

# The topography should be either isolated, or periodic with the domain length an integer multiple of its wavelength (otherwise the periodic extension is discontinuous and the solution shows ringing).

# Adapted from the Holton and Hakim MATLAB routine lee_wave_1.m:
# - fixed issues with k versus kmin, and with the use of the FFT (both positive and negative wavenumbers are needed, otherwise the result is aliased)
# - changed the topography to a Gaussian ridge for the problem set

import numpy as np
import matplotlib.pyplot as plt
import matplotlib.cm as cm

print("Topographic wave for isolated ridge")

# User input parameter

u0 = input('give mean zonal wind in m/s  ')
u0 = float(u0)
# The solution divides by u0 (m^2 = N^2/u0^2 - k^2, u = -m*w/k with w ~ u0),
# so a zero or non-finite wind has no steady lee-wave solution; fail clearly
# here instead of with a ZeroDivisionError / NaN fields further down.
if u0 == 0 or not np.isfinite(u0):
    raise ValueError('mean zonal wind must be a finite, nonzero number of m/s')

# Set physical parameters

Lz = 15e3    # Depth of domain in m
L = 3e3      # Width of ridge in m as measured by standard deviation
Lx = 70e3    # Lx is the length of the domain in m
bv = 1.0e-4  # Buoyancy frequency squared s^-2
g = 9.81     # gravitational acceleration, m/s^2 (not used in this script)

# Numerical parameters

NL = 128  # Number of vertical gridpoints
N = 512   # Number of modes for Fourier transform

# Defining fields and variables

kmin = 2*np.pi/Lx  # Lowest zonal wavenumber
zz = np.linspace(0, Lz, NL)
# The FFT treats the N samples as one period of length N*dx.  Excluding the
# end point gives dx = Lx/N, so that period is exactly Lx and the FFT
# wavenumbers are integer multiples of kmin (a sample at x = Lx would just
# duplicate x = 0 in the periodic domain).
xx = np.linspace(0, Lx, N, endpoint=False)
X, Z = np.meshgrid(xx, zz)
xm = Lx/2  # Location of ridge
hx = 5e2*np.exp(-(xx - xm)**2/2/L**2)  # Gaussian ridge, 500 m peak height

# Fourier transform hx(x) to get the spectral coefficients hn(k)
hn = np.fft.fft(hx)
hn[0] = 0 + 0.0j  # Set mean to zero (a uniform height offset forces no wave)

# Zonal wavenumbers (rad/m) in FFT order: zero, positive, then negative.
# They are derived from the actual grid spacing so that they match the
# period the FFT assumes (nsize*dx), whether or not xx includes x = Lx.
nsize = hx.size
ks = 2*np.pi*np.fft.fftfreq(nsize, d=xx[1] - xx[0])

# Vertical wavenumber from the dispersion relation m^2 = N^2/u0^2 - k^2
m2 = bv/u0**2 - ks**2
# Principal square root: real and >= 0 where m2 >= 0 (vertically propagating
# modes), +i*sqrt(-m2) where m2 < 0 so exp(i*m*z) decays with height
# (evanescent modes).
m = np.sqrt(m2.astype(complex))
# Radiation condition for propagating modes: upward energy propagation
# requires sign(m) == sign(k*u0), so the sign follows the wind direction.
propagating = m2 > 0
m[propagating] *= np.where(ks[propagating]*u0 < 0, -1.0, 1.0)

M = np.tile(m, (NL, 1))
K = np.tile(ks, (NL, 1))

# Lower boundary condition w(z=0) = u0*dh/dx, i.e. i*k*u0*hn in spectral space
wsurf = 1j*ks*u0*hn

WSURF = np.tile(wsurf, (NL, 1))

W = WSURF*np.exp(1j*M*Z)

# Continuity (i*k*u + i*m*w = 0) gives u = -m*w/k, and the streamfunction
# with u = -dpsi/dz is psi = i*u/m.  Column 0 (k = 0) stays zero because
# hn[0] was set to zero above.
U = np.zeros_like(W)
PSI = np.zeros_like(W)
U[:, 1:] = -M[:, 1:]*W[:, 1:]/K[:, 1:]
PSI[:, 1:] = 1j*U[:, 1:]/M[:, 1:]

UXZ = np.real(np.fft.ifft(U, axis=1))
WXZ = np.real(np.fft.ifft(W, axis=1))
# Total streamfunction = perturbation + basic state (-u0*z)
PSIXZ = np.real(np.fft.ifft(PSI, axis=1)) - u0*Z
# Undisturbed basic-state streamfunction, drawn for reference
PSIXZ_ref = -u0*Z

print('max u prime = %.5f' % np.max(UXZ))

# Vertical velocity, with a colour scale symmetric about zero
plt.figure(1)
cbar = np.max(np.abs(WXZ))
cs = plt.contourf(X/1000, Z/1000, WXZ, vmin=-cbar, vmax=cbar,
                  cmap=cm.bwr, levels=128)
plt.colorbar(cs)
plt.title('vertical velocity (m/s)')
plt.xlabel('horizontal distance (km)')
plt.ylabel('height (km)')

# Streamlines (solid) and undisturbed basic-state streamlines (dashed).
# The reference is drawn at the same contour levels as the solid lines so
# each dashed line marks the undeflected position of a solid streamline;
# otherwise each call picks its own levels from its own data range.
plt.figure(2)
cs = plt.contour(X/1000, Z/1000, PSIXZ)
plt.contour(X/1000, Z/1000, PSIXZ_ref, levels=cs.levels, linestyles='dashed')
plt.title('streamlines')
plt.xlabel('horizontal distance (km)')
plt.ylabel('height (km)')
plt.show()
