121 lines
4.4 KiB
Python
121 lines
4.4 KiB
Python
|
# Copyright (c) 2022, NVIDIA CORPORATION & AFFILIATES. All rights reserved.
|
||
|
#
|
||
|
# NVIDIA CORPORATION & AFFILIATES and its licensors retain all intellectual property
|
||
|
# and proprietary rights in and to this software, related documentation
|
||
|
# and any modifications thereto. Any use, reproduction, disclosure or
|
||
|
# distribution of this software and related documentation without an express
|
||
|
# license agreement from NVIDIA CORPORATION & AFFILIATES is strictly prohibited.
|
||
|
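""" src: ddim/model/ema.py

Implements ``EMA``, a wrapper around a ``torch.optim`` optimizer that also keeps an
exponential moving average (EMA) of the model parameters.

usage:
    optimizer = EMA(torch.optim.Adam(model.parameters()), ema_decay=0.9999)
    ...
    loss.backward()
    optimizer.step()  # wrapped optimizer update, then EMA update
    ...
    # evaluate with the EMA weights, then swap the training weights back
    optimizer.swap_parameters_with_ema(store_params_in_ema=True)
    ...
    optimizer.swap_parameters_with_ema(store_params_in_ema=True)

The usage notes originally given here describe DDIM's ``EMAHelper`` class, which is
not defined in this module:
    ema_helper = EMAHelper(mu=self.config.model.ema_rate)
    ema_helper.register(model)
    ema_helper.load_state_dict(states[-1])
    ema_helper.ema(model)

    after optimizer.step()
    ema_helper.update(model)

copied and modified from
https://github.com/NVlabs/LSGM/blob/5eae2f385c014f2250c3130152b6be711f6a3a5a/util/ema.py
"""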
|
||
|
|
||
|
import warnings
|
||
|
import torch
|
||
|
from torch.optim import Optimizer
|
||
|
from loguru import logger
|
||
|
import torch.nn as nn
|
||
|
import os
|
||
|
|
||
|
|
||
|
class EMA(Optimizer):
    """Optimizer wrapper that keeps an exponential moving average (EMA) of the parameters.

    The wrapped optimizer performs the real parameter update. After each of its steps, every
    parameter that received a gradient is blended into an EMA copy stored in the wrapped
    optimizer's per-parameter state under the key ``'ema'``:

        ema = ema_decay * ema + (1 - ema_decay) * param

    Because the EMA copies live in ``optimizer.state``, they travel with ``state_dict()`` /
    ``load_state_dict()`` together with the rest of the optimizer state.

    Note: ``Optimizer.__init__`` is not called. ``state`` and ``param_groups`` alias the
    wrapped optimizer's objects, and state (de)serialization is delegated to the wrapped
    optimizer.
    """

    def __init__(self, opt, ema_decay):
        """
        Args:
            opt: the optimizer to wrap; it performs the actual parameter updates.
            ema_decay: EMA decay rate. EMA tracking is enabled only when ``ema_decay > 0``.
        """
        self.ema_decay = ema_decay
        self.apply_ema = self.ema_decay > 0.
        logger.info('[EMA] apply={}', self.apply_ema)
        self.optimizer = opt
        # Share (not copy) the wrapped optimizer's containers so reads through this
        # wrapper see the live objects.
        self.state = opt.state
        self.param_groups = opt.param_groups

    def zero_grad(self, *args, **kwargs):
        """Clear gradients through the wrapped optimizer.

        All arguments (e.g. ``set_to_none``) are forwarded unchanged, so callers can use the
        same signature as ``torch.optim.Optimizer.zero_grad``.
        """
        self.optimizer.zero_grad(*args, **kwargs)

    def step(self, *args, **kwargs):
        """Run the wrapped optimizer's step, then update the EMA copies.

        ``*args`` / ``**kwargs`` (e.g. a closure) are forwarded to the wrapped optimizer and
        its return value is returned. Parameters whose ``grad`` is ``None`` are skipped: their
        EMA copy is neither created nor updated. The EMA copy of a parameter is initialised
        from its value right after the first step in which it has a gradient.
        """
        retval = self.optimizer.step(*args, **kwargs)

        # stop here if we are not applying EMA
        if not self.apply_ema:
            return retval

        # The EMA bookkeeping must not be tracked by autograd.
        with torch.no_grad():
            for group in self.optimizer.param_groups:
                self._update_group_ema(group)

        return retval

    def _update_group_ema(self, group):
        """Blend the current values of one param group's parameters into their EMA copies.

        Parameters are bucketed by (shape, dtype, device) so each bucket is updated with one
        stacked ``mul_``/``add_`` (presumably to cut per-parameter overhead). Keying on dtype
        and device as well as shape keeps ``torch.stack`` from mixing incompatible tensors.
        """
        state = self.optimizer.state
        buckets = {}
        for p in group['params']:
            if p.grad is None:
                continue
            if 'ema' not in state[p]:
                # First update for this parameter: start the average at its current value.
                state[p]['ema'] = p.data.clone()
            buckets.setdefault((p.shape, p.dtype, p.device), []).append(p)

        for bucket in buckets.values():
            data = torch.stack([p.data for p in bucket], dim=0)
            ema = torch.stack([state[p]['ema'] for p in bucket], dim=0)
            ema.mul_(self.ema_decay).add_(data, alpha=1. - self.ema_decay)
            # torch.stack copies, so write the updated rows back into the state. Index with
            # ``ema[i]`` (not ``ema[i, :]``) so 0-dim parameters work too.
            for i, p in enumerate(bucket):
                state[p]['ema'] = ema[i]

    def state_dict(self):
        """Return the wrapped optimizer's state dict, which includes the ``'ema'`` copies.

        Delegated because the base-class implementation may rely on attributes that only
        ``Optimizer.__init__`` sets, and this wrapper never calls it.
        """
        return self.optimizer.state_dict()

    def load_state_dict(self, state_dict):
        """Load optimizer state, including the EMA copies, into the wrapped optimizer.

        Loading is delegated to the wrapped optimizer, for the same reason as
        ``state_dict``. Loading may rebind the wrapped optimizer's ``state`` and
        ``param_groups``, so the aliases held by this wrapper are refreshed afterwards.
        """
        self.optimizer.load_state_dict(state_dict)
        self.state = self.optimizer.state
        self.param_groups = self.optimizer.param_groups

    def swap_parameters_with_ema(self, store_params_in_ema):
        """Replace each trainable parameter's data with its EMA copy.

        Args:
            store_params_in_ema: if True, the current parameter values are written into the
                EMA slots, so a second call with True swaps them back. If False, the current
                values are not preserved, and ``p.data`` then shares storage with the EMA
                tensor.

        Frozen parameters (``requires_grad=False``) and parameters without an EMA copy are
        left unchanged; the number of the latter is logged.
        """
        # stop here if we are not applying EMA
        if not self.apply_ema:
            warnings.warn(
                'swap_parameters_with_ema was called when there is no EMA weights.')
            return
        logger.info('swap with ema')
        count_no_found = 0
        for group in self.optimizer.param_groups:
            for p in group['params']:
                if not p.requires_grad:
                    continue
                # ``.get`` never inserts a new entry, even if ``state`` is a defaultdict.
                param_state = self.optimizer.state.get(p)
                if not param_state or 'ema' not in param_state:
                    count_no_found += 1
                    continue
                ema = param_state['ema']
                if store_params_in_ema:
                    tmp = p.data.detach()
                    p.data = ema.detach()
                    param_state['ema'] = tmp
                else:
                    p.data = ema.detach()
        if count_no_found:
            logger.info('[EMA] no EMA copy for {} trainable parameter(s); left unchanged',
                        count_no_found)
|