diff --git a/analysis.py b/analysis.py index 65404d5..9eb1022 100644 --- a/analysis.py +++ b/analysis.py @@ -5,9 +5,10 @@ from fvcore.nn import FlopCountAnalysis from classification_ScanObjectNN.models import pointMLPElite model = pointMLPElite() +model.eval() # model = deit_tiny_patch16_224() -inputs = (torch.randn((1,3,1024)),) +inputs = (torch.randn((1,3,1024))) k = 1024.0 flops = FlopCountAnalysis(model, inputs).total() print(f"Flops : {flops}")