diff --git a/train.py b/train.py index ecad51d..200ccfa 100644 --- a/train.py +++ b/train.py @@ -89,7 +89,7 @@ def train_net(net, pbar.update(imgs.shape[0]) global_step += 1 - if global_step % (len(dataset) // (10 * batch_size)) == 0: + if global_step % (n_train // (10 * batch_size)) == 0: for tag, value in net.named_parameters(): tag = tag.replace('.', '/') writer.add_histogram('weights/' + tag, value.data.cpu().numpy(), global_step)