LION/utils/checker.py
2023-01-23 00:14:49 -05:00

81 lines
2.8 KiB
Python

# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
#
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
# and proprietary rights in and to this software, related documentation
# and any modifications thereto. Any use, reproduction, disclosure or
# distribution of this software and related documentation without an express
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
import torch
def CHECKDIM(tensor, dim, val):
    """Assert that dimension `dim` of `tensor` has an expected size.

    Args:
        tensor: an object exposing `.shape`, or a list of such objects
            (each element is checked recursively with the same `dim`/`val`).
        dim: index of the dimension to inspect; negative indices count
            from the end, as with normal sequence indexing.
        val: the expected size, or a list of acceptable sizes.

    Returns:
        True when every check passes.

    Raises:
        AssertionError: if `dim` is not a valid index into the shape, or
            the size at `dim` does not match `val`.
    """
    if isinstance(tensor, list):
        for item in tensor:
            CHECKDIM(item, dim, val)
        return True

    shape = tensor.shape
    ndim = len(shape)
    # `dim` must be a valid index into `shape`. The previous `len(shape) >= dim`
    # test let dim == ndim through and then failed with a bare IndexError.
    assert -ndim <= dim < ndim, \
        f'expect tensor with shape: {shape} to have dim {dim}, but it has only {ndim} dims'
    if isinstance(val, list):
        assert shape[dim] in val, 'expect {} to have {} dim shape {}'.format(
            shape, dim, val)
    else:
        assert shape[dim] == val, 'expect tensor with shape: {} having dim {} as {}'.format(
            shape, dim, val)
    return True
def CHECK5D(tensor, *args):
    """Assert that `tensor` and every extra positional argument is 5-D.

    Returns the shape of the first argument, `tensor`.
    Raises AssertionError on the first argument whose shape is not 5-D.
    """
    # Check the primary tensor first, then the extras in call order.
    for candidate in (tensor,) + args:
        rank = len(candidate.shape)
        assert rank == 5, 'get {} {}'.format(candidate.shape, rank)
    return tensor.shape
def CHECK3D(tensor, *args):
    """Assert that `tensor` and every extra positional argument is 3-D.

    Returns the shape of the first argument, `tensor`.
    Raises AssertionError on the first argument whose shape is not 3-D.
    """
    # Check the primary tensor first, then the extras in call order.
    for candidate in (tensor,) + args:
        rank = len(candidate.shape)
        assert rank == 3, 'get {} {}'.format(candidate.shape, rank)
    return tensor.shape
def CHECK4D(tensor):
    """Assert that `tensor.shape` has exactly four entries and return the shape."""
    shape = tensor.shape
    rank = len(shape)
    assert rank == 4, 'get {} {}'.format(shape, rank)
    return shape
def CHECKND(tensor, N):
    """Assert that `tensor.shape` has exactly `N` entries and return the shape."""
    shape = tensor.shape
    rank = len(shape)
    assert rank == N, 'get tensor shape:{} DIM={}, expect:{}'.format(shape, rank, N)
    return shape
def CHECK2D(tensor):
    """Assert that `tensor.shape` has exactly two entries and return the shape."""
    shape = tensor.shape
    rank = len(shape)
    assert rank == 2, 'get {} {}'.format(shape, rank)
    return shape
def CHECK_N3or6(input):
    """Assert that `input` is a 2-D tensor of shape (N,3) or (N,6).

    Returns `input.shape` when all checks pass; raises AssertionError otherwise.
    """
    CHECK_TENSOR(input)
    CHECK2D(input)
    width = input.shape[1]
    assert width in (3, 6), f'expect shape N,3 or N,6; get {input.shape}'
    return input.shape
def CHECK_N3or6or9(input):
    """Assert that `input` is a 2-D tensor of shape (N,3), (N,6) or (N,9).

    Returns `input.shape` when all checks pass; raises AssertionError otherwise.
    """
    CHECK_TENSOR(input)
    CHECK2D(input)
    # The message lists all three accepted widths; it previously omitted N,9.
    assert input.shape[1] in (3, 6, 9), \
        f'expect shape N,3 or N,6 or N,9; get {input.shape}'
    return input.shape
def CHECK_N3(input):
    """Assert that `input` is a 2-D tensor of shape (N,3) and return its shape."""
    # Order matters: type check first, then rank, then the size of dim 1.
    for validate in (CHECK_TENSOR, CHECK2D):
        validate(input)
    CHECKDIM(input, 1, 3)
    return input.shape
def CHECK_TENSOR(input):
    """Assert that `input` is a torch tensor (per `torch.is_tensor`)."""
    is_tensor = torch.is_tensor(input)
    assert is_tensor, f'expect tensor, get {type(input)}'
def CHECKEQ(a, b):
    """Assert that `a == b`; the failure message reports both values."""
    matches = a == b
    assert matches, f'expect a=b, get a={a} and b={b}'
def CHECKSIZE(t, values):
    """Assert that `t` has `len(values)` dims and that each dim matches `values`.

    Entry `values[i]` is passed to CHECKDIM as the expected size of dim `i`.
    """
    CHECKND(t, len(values))
    for axis, expected in enumerate(values):
        CHECKDIM(t, axis, expected)
def CHECKSAMESIZE(t1, t2):
    """Assert that `t1` has the same shape as `t2` (checked via CHECKSIZE)."""
    reference_shape = t2.shape
    CHECKSIZE(t1, reference_shape)