Update analysis.py

This commit is contained in:
Xu Ma 2022-04-19 15:45:58 -04:00
parent 09a60de6bc
commit 5cb88d65e3

View file

@ -5,9 +5,10 @@ from fvcore.nn import FlopCountAnalysis
from classification_ScanObjectNN.models import pointMLPElite from classification_ScanObjectNN.models import pointMLPElite
model = pointMLPElite() model = pointMLPElite()
model.eval()
# model = deit_tiny_patch16_224() # model = deit_tiny_patch16_224()
inputs = (torch.randn((1,3,1024)),) inputs = (torch.randn((1,3,1024)))
k = 1024.0 k = 1024.0
flops = FlopCountAnalysis(model, inputs).total() flops = FlopCountAnalysis(model, inputs).total()
print(f"Flops : {flops}") print(f"Flops : {flops}")