Create analysis.py
This commit is contained in:
parent
c1d6235405
commit
09a60de6bc
19
analysis.py
Normal file
19
analysis.py
Normal file
|
@ -0,0 +1,19 @@
|
|||
import torch
|
||||
import fvcore.nn
|
||||
import fvcore.common
|
||||
from fvcore.nn import FlopCountAnalysis
|
||||
from classification_ScanObjectNN.models import pointMLPElite
|
||||
|
||||
model = pointMLPElite()
|
||||
# model = deit_tiny_patch16_224()
|
||||
|
||||
inputs = (torch.randn((1,3,1024)),)
|
||||
k = 1024.0
|
||||
flops = FlopCountAnalysis(model, inputs).total()
|
||||
print(f"Flops : {flops}")
|
||||
flops = flops/(k**3)
|
||||
print(f"Flops : {flops:.1f}G")
|
||||
params = fvcore.nn.parameter_count(model)[""]
|
||||
print(f"Params : {params}")
|
||||
params = params/(k**2)
|
||||
print(f"Params : {params:.1f}M")
|
Loading…
Reference in a new issue