# Trace how a (270, 61) input size propagates through a stack of `cnn` calls
# (encoder) followed by a stack of `transcnn` calls (decoder), printing the
# intermediate and final results.
#
# NOTE(review): `cnn` and `transcnn` are defined elsewhere and are not visible
# here. From the way their results are chained into the next call and printed
# as sizes, they appear to map an input size to an output size given
# kernelsize / stride / padding. A scalar argument presumably applies to both
# dimensions and a 2-tuple gives per-dimension values -- confirm against their
# definitions.
#
# Scalar arguments are written as bare ints: the original `(3)`, `(1)`, `(0)`
# are just parenthesised ints, not 1-tuples, so the calls are unchanged.

input_size = (270, 61)
print('Input size = ', input_size)

# ---- Encoder: ten unpadded 3x3 `cnn` calls; layers 3, 6 and 9 use stride 2,
# the rest stride 1. A final (3, 1) kernel produces `output_cnn`. ----
print('----------- Start CNN -----------')
cnn1 = cnn(input_size, kernelsize=3, stride=1, padding=0)
cnn2 = cnn(cnn1, kernelsize=3, stride=1, padding=0)
cnn3 = cnn(cnn2, kernelsize=3, stride=2, padding=0)
cnn4 = cnn(cnn3, kernelsize=3, stride=1, padding=0)
cnn5 = cnn(cnn4, kernelsize=3, stride=1, padding=0)
cnn6 = cnn(cnn5, kernelsize=3, stride=2, padding=0)
cnn7 = cnn(cnn6, kernelsize=3, stride=1, padding=0)
cnn8 = cnn(cnn7, kernelsize=3, stride=1, padding=0)
cnn9 = cnn(cnn8, kernelsize=3, stride=2, padding=0)
cnn10 = cnn(cnn9, kernelsize=3, stride=1, padding=0)
output_cnn = cnn(cnn10, kernelsize=(3, 1), stride=1, padding=0)
print("Output of cnn:", output_cnn)

# ---- Decoder: each unpadded `transcnn` call is immediately followed by a
# 3x3, stride-1, padding-1 `cnn` call on its result (presumably
# size-preserving under the usual output-size formula -- confirm). ----
print('---------- Start TransCNN -------')
transcnn1 = transcnn(output_cnn, kernelsize=(5, 3), stride=(3, 3), padding=0)
transcnn1 = cnn(transcnn1, kernelsize=3, stride=1, padding=1)

# NOTE(review): layer 2 passes a scalar kernel 5 while layers 3-4 use (5, 3);
# confirm this difference is intentional.
transcnn2 = transcnn(transcnn1, kernelsize=5, stride=(2, 3), padding=0)
transcnn2 = cnn(transcnn2, kernelsize=3, stride=1, padding=1)

transcnn3 = transcnn(transcnn2, kernelsize=(5, 3), stride=(2, 3), padding=0)
transcnn3 = cnn(transcnn3, kernelsize=3, stride=1, padding=1)

transcnn4 = transcnn(transcnn3, kernelsize=(5, 3), stride=(2, 3), padding=0)
transcnn4 = cnn(transcnn4, kernelsize=3, stride=1, padding=1)

transcnn5 = transcnn(transcnn4, kernelsize=3, stride=(1, 2), padding=0)
transcnn5 = cnn(transcnn5, kernelsize=3, stride=1, padding=1)

transcnn6 = transcnn(transcnn5, kernelsize=3, stride=(1, 2), padding=0)
transcnn6 = cnn(transcnn6, kernelsize=3, stride=1, padding=1)

transcnn7 = transcnn(transcnn6, kernelsize=(3, 3), stride=(1, 2), padding=0)
transcnn7 = cnn(transcnn7, kernelsize=3, stride=1, padding=1)

transcnn8 = transcnn(transcnn7, kernelsize=(6, 3), stride=(1, 1), padding=0)
transcnn8 = cnn(transcnn8, kernelsize=3, stride=1, padding=1)

transcnn9 = transcnn(transcnn8, kernelsize=(5, 3), stride=(1, 1), padding=0)
transcnn9 = cnn(transcnn9, kernelsize=3, stride=1, padding=1)

# Final unpadded 3x3 `cnn` call gives the decoder output.
output_transcnn = cnn(transcnn9, kernelsize=3, stride=1, padding=0)
print("Output of transcnn:", output_transcnn)

# Summary line mirroring the 'Input size = ' print at the top.
print("Output size = ", output_transcnn)