# Web-page residue: "Like" / "Dislike" vote buttons (not part of the program).
# -*- coding: utf-8 -*-
import time

import matplotlib.pyplot as plt
import numpy as np
from mpl_toolkits.mplot3d import axes3d  # noqa: F401  (kept: registers the '3d' projection on older matplotlib)


def createMatrix(m, n, up=100.0, down=0.0, left=75.0, right=50.0):
    """Build the initial grid: zero interior framed by constant edge values.

    The grid has ``n`` interior rows and ``m`` interior columns plus a
    one-cell frame on every side, so the result has shape ``(n + 2, m + 2)``.

    Args:
        m: Number of interior columns.
        n: Number of interior rows.
        up: Value written to the first row (default 100).
        down: Value written to the last row (default 0).
        left: Value written to the first column (default 75).
        right: Value written to the last column (default 50).

    Returns:
        A float ``numpy.ndarray`` of shape ``(n + 2, m + 2)``.
    """
    A = np.zeros((n + 2, m + 2))

    # Rows first, columns second: the left/right values therefore win at the
    # four corner cells, exactly as in the original assignment order.
    A[0, :] = up
    A[n + 1, :] = down
    A[:, 0] = left
    A[:, m + 1] = right

    return A


def oneIter(A, r_lf, r_rt):
    """Advance the interior of ``A`` by one pass that is implicit along rows.

    For every interior row ``j`` the new interior values ``u`` solve the
    tridiagonal system

        -r_lf/2 * u[i-1] + (1 + r_lf) * u[i] - r_lf/2 * u[i+1]
            = r_rt/2 * A[j-1, i] + (1 - r_rt) * A[j, i] + r_rt/2 * A[j+1, i]

    where the left/right frame cells of row ``j`` supply the known
    ``u[0]`` and ``u[m+1]``.  The frame itself is copied through unchanged.

    Args:
        A: 2-D array of shape ``(n + 2, m + 2)`` (interior plus a one-cell
            frame).  It is not modified.
        r_lf: Mesh ratio used along the row (implicit) direction.
        r_rt: Mesh ratio used along the column (explicit) direction.

    Returns:
        A new float array of the same shape holding the updated grid.
    """
    old = np.asarray(A, dtype=float)
    n_rows, n_cols = old.shape
    m = n_cols - 2  # interior columns
    n = n_rows - 2  # interior rows

    # Float copy so that an integer input grid does not truncate the result.
    B = old.copy()
    if m < 1 or n < 1:
        # No interior cells: nothing to solve for.
        return B

    # 2.0 keeps the halving a true division even for integer ratios on
    # Python 2.
    half_imp = r_lf / 2.0
    half_exp = r_rt / 2.0

    # Tridiagonal coefficient matrix; identical for every row.
    M = (1.0 + r_lf) * np.eye(m) - half_imp * (np.eye(m, k=1) + np.eye(m, k=-1))

    # Explicit part for all interior rows at once; row k of ``rhs`` is the
    # right-hand side belonging to grid row k + 1.
    rhs = (half_exp * old[0:n, 1:m + 1]
           + (1.0 - r_rt) * old[1:n + 1, 1:m + 1]
           + half_exp * old[2:n + 2, 1:m + 1])

    # Known frame values move to the right-hand side.  The right edge lives
    # in column m + 1; the original read column m - 1 (an interior cell), so
    # the right boundary never influenced the solution.
    rhs[:, 0] += half_imp * old[1:n + 1, 0]
    rhs[:, m - 1] += half_imp * old[1:n + 1, m + 1]

    # One solve with n right-hand-side columns instead of n separate solves.
    B[1:n + 1, 1:m + 1] = np.linalg.solve(M, rhs.T).T

    return B


def computeA(m, n, rx, ry, iter):
    """Build the initial grid and run ``iter`` full sweeps on it.

    Args:
        m: Number of interior columns.
        n: Number of interior rows.
        rx: Mesh ratio for the row direction.
        ry: Mesh ratio for the column direction.
        iter: Number of full sweeps to perform.  (The name shadows the
            builtin; it is kept because callers may pass it by keyword.)

    Returns:
        The grid after ``iter`` sweeps, shape ``(n + 2, m + 2)``.
    """
    A = createMatrix(m, n)
    # Single-argument print(...) parses and prints identically on Python 2
    # and Python 3.
    print('total iter=%s' % (iter,))
    # The original looped over range(1, iter) and so ran only iter - 1 sweeps
    # while announcing ``iter`` of them.
    for step in range(1, iter + 1):
        print('iter num=%s' % (step,))
        # Same row-pass / column-pass pair that computeOneIter performs.
        A = computeOneIter(A, m, n, rx, ry)

    return A


def computeOneIter(A, m, n, rx, ry):
    """Run one full sweep: a pass along rows, then a pass along columns.

    The column pass reuses ``oneIter`` on the transposed grid with the two
    mesh ratios swapped, and the result is transposed back.

    Args:
        A: Grid to advance.
        m: Unused; kept for the existing call signature.
        n: Unused; kept for the existing call signature.
        rx: Mesh ratio for the row direction.
        ry: Mesh ratio for the column direction.

    Returns:
        The grid after both passes, in the same orientation as ``A``.
    """
    row_pass = oneIter(A, rx, ry)
    col_pass = oneIter(np.transpose(row_pass), ry, rx)
    return np.transpose(col_pass)


def getStart():
    """Run the simulation on [0, 20] x [0, 30] and animate it as a wireframe.

    The grid is advanced with ``computeOneIter`` until the time span ``T`` is
    covered, redrawing a 3-D wireframe after every sweep.

    Returns:
        A tuple ``(A, x, y)``: the final grid of shape
        ``(len(y), len(x))`` and the node coordinates along each axis,
        boundary nodes included.
    """
    X_INTERVAL = [0, 20]
    Y_INTERVAL = [0, 30]
    T = [0, 10]
    deltax = 0.5
    deltay = 0.3
    h_min = min(deltax, deltay)
    tao = 1.0 / 3 * h_min * h_min

    # Interior node counts.  round() instead of bare int(): the quotient is a
    # float and truncation would drop a node if it landed just below an
    # integer.
    m = int(round((X_INTERVAL[1] - X_INTERVAL[0]) / deltax)) - 1
    n = int(round((Y_INTERVAL[1] - Y_INTERVAL[0]) / deltay)) - 1
    print('m=%s,n=%s' % (m, n))

    # m and n count interior nodes, so the axes need m + 2 / n + 2 points to
    # include both boundaries and to be spaced exactly deltax / deltay apart.
    # The original used m and n as total point counts (and a grid with m - 2
    # by n - 2 interior cells), so the real spacing disagreed with the
    # deltax / deltay used in the mesh ratios below.
    x = np.linspace(X_INTERVAL[0], X_INTERVAL[1], m + 2)
    y = np.linspace(Y_INTERVAL[0], Y_INTERVAL[1], n + 2)

    rx = tao / deltax / deltax
    ry = tao / deltay / deltay

    # animation
    fig = plt.figure()
    ax = fig.add_subplot(111, projection='3d')
    X, Y = np.meshgrid(x, y)  # both (n + 2, m + 2), matching the grid

    wframe = None
    n_steps = int((T[1] - T[0]) / tao)
    A = createMatrix(m, n)
    for i in range(n_steps):
        A = computeOneIter(A, m, n, rx, ry)

        # Drop the previous frame before drawing the next one.
        # Artist.remove() replaces ax.collections.remove(...), which current
        # Matplotlib no longer allows.
        if wframe is not None:
            wframe.remove()

        wframe = ax.plot_wireframe(X, Y, A, rstride=2, cstride=2)
        plt.pause(0.01)
        # Same text as the Python 2 statement ``print 'iter=',i``.
        print('iter= %s' % i)

    return A, x, y


# Script entry point: run the animated simulation only when executed directly.
if __name__ == "__main__":
    getStart()

# Copyright (c) 2003-2013 www.wpsshop.cn. All rights reserved.