# PointMLP/analysis.py
# Reports the FLOP count and parameter count of pointMLPElite using fvcore.
import fvcore.common
import fvcore.nn
import torch
from fvcore.nn import FlopCountAnalysis

from classification_ScanObjectNN.models import pointMLPElite

# Report the compute cost (FLOPs) and size (parameter count) of pointMLPElite
# for a single dummy point cloud, using fvcore's counters.

NUM_POINTS = 1024  # points per cloud in the dummy input (same as before)

# SI prefixes are decimal: G = 1e9, M = 1e6. The previous version divided by
# 1024**3 / 1024**2 while printing "G" / "M", which under-reported both figures
# by roughly 7% and 5% respectively.
GIGA = 1e9
MEGA = 1e6

model = pointMLPElite()
model.eval()  # put dropout / batch-norm layers in inference mode before tracing

# Dummy batch of one cloud with shape (1, 3, NUM_POINTS).
# NOTE(review): presumably (batch, xyz coords, points) -- confirm against
# pointMLPElite.forward.
inputs = torch.randn((1, 3, NUM_POINTS))

# fvcore traces the model to count ops; no autograd graph is needed for that.
with torch.no_grad():
    flops = FlopCountAnalysis(model, inputs).total()

# NOTE: fvcore counts one fused multiply-add as a single flop, so this figure
# is effectively a multiply-accumulate (MAC) count.
print(f"Flops : {flops}")
# Two decimals: one decimal hides most of the value for small models.
print(f"Flops : {flops / GIGA:.2f}G")

# parameter_count maps module names to counts; the "" entry is the whole model.
params = fvcore.nn.parameter_count(model)[""]
print(f"Params : {params}")
print(f"Params : {params / MEGA:.2f}M")