Mirror of https://github.com/Laurent2916/REVA-QCAV.git, synced 2024-11-08 14:39:00 +00:00
feat!: DETR
commit 8691735779 (parent 9d36719335)
@@ -1,16 +0,0 @@
FROM nvidia/cuda:11.7.1-cudnn8-devel-ubuntu20.04

WORKDIR /workspace

ARG USERNAME=vscode
ARG UID=1000
ARG GID=${UID}

COPY library-scripts/*.sh library-scripts/*.env /tmp/library-scripts/
RUN apt-get update \
    && export DEBIAN_FRONTEND=noninteractive \
    && /bin/bash /tmp/library-scripts/common-debian.sh \
    && apt-get install -y python3 python3-pip && pip install --upgrade pip --no-input \
    && apt-get autoremove -y && apt-get clean -y && rm -rf /var/lib/apt/lists/* /tmp/library-scripts

RUN su - vscode -c "curl -sSL https://install.python-poetry.org | python3 -"
@@ -1,24 +0,0 @@
{
    "name": "sphereDetect-dev",
    "dockerComposeFile": "docker-compose.yml",
    "service": "dev",
    "remoteUser": "vscode",
    "workspaceFolder": "/workspace",
    "postAttachCommand": "poetry install --with all",
    "extensions": [
        "ms-vscode-remote.remote-containers",
        "ms-azuretools.vscode-docker",
        "editorconfig.editorconfig",
        "njpwerner.autodocstring",
        "ms-python.python",
        "ms-toolsai.jupyter",
        "eamodio.gitlens"
    ],
    "runArgs": [
        "--gpus",
        "all"
    ],
    "forwardPorts": [
        8080
    ]
}
@@ -1,28 +0,0 @@
version: "3"

services:
  # development container
  dev:
    container_name: dev
    build:
      context: .
      dockerfile: Dockerfile
    volumes:
      - ..:/workspace
    stdin_open: true
    network_mode: service:wandb
    deploy:
      resources:
        reservations:
          devices:
            - capabilities:
                - gpu

  # wandb dashboard
  wandb:
    hostname: wandb-local
    container_name: wandb-local
    image: wandb/local
    ports:
      - 8080:8080
    restart: unless-stopped
@@ -1,454 +0,0 @@
#!/usr/bin/env bash
#-------------------------------------------------------------------------------------------------------------
# Copyright (c) Microsoft Corporation. All rights reserved.
# Licensed under the MIT License. See https://go.microsoft.com/fwlink/?linkid=2090316 for license information.
#-------------------------------------------------------------------------------------------------------------
#
# Docs: https://github.com/microsoft/vscode-dev-containers/blob/main/script-library/docs/common.md
# Maintainer: The VS Code and Codespaces Teams

set -e

INSTALL_ZSH=${INSTALLZSH:-"true"}
INSTALL_OH_MY_ZSH=${INSTALLOHMYZSH:-"true"}
UPGRADE_PACKAGES=${UPGRADEPACKAGES:-"true"}
USERNAME=${USERNAME:-"automatic"}
USER_UID=${UID:-"automatic"}
USER_GID=${GID:-"automatic"}
ADD_NON_FREE_PACKAGES=${NONFREEPACKAGES:-"false"}

DEV_CONTAINERS_DIR="/usr/local/etc/vscode-dev-containers"
MARKER_FILE="${DEV_CONTAINERS_DIR}/common"

if [ "$(id -u)" -ne 0 ]; then
    echo -e 'Script must be run as root. Use sudo, su, or add "USER root" to your Dockerfile before running this script.'
    exit 1
fi

# Ensure that login shells get the correct path if the user updated the PATH using ENV.
rm -f /etc/profile.d/00-restore-env.sh
echo "export PATH=${PATH//$(sh -lc 'echo $PATH')/\$PATH}" >/etc/profile.d/00-restore-env.sh
chmod +x /etc/profile.d/00-restore-env.sh

# If in automatic mode, determine if a user already exists, if not use vscode
if [ "${USERNAME}" = "auto" ] || [ "${USERNAME}" = "automatic" ]; then
    USERNAME=""
    POSSIBLE_USERS=("vscode" "node" "codespace" "$(awk -v val=1000 -F ":" '$3==val{print $1}' /etc/passwd)")
    for CURRENT_USER in "${POSSIBLE_USERS[@]}"; do
        if id -u ${CURRENT_USER} >/dev/null 2>&1; then
            USERNAME=${CURRENT_USER}
            break
        fi
    done
    if [ "${USERNAME}" = "" ]; then
        USERNAME=vscode
    fi
elif [ "${USERNAME}" = "none" ]; then
    USERNAME=root
    USER_UID=0
    USER_GID=0
fi

# Load markers to see which steps have already run
if [ -f "${MARKER_FILE}" ]; then
    echo "Marker file found:"
    cat "${MARKER_FILE}"
    source "${MARKER_FILE}"
fi

# Ensure apt is in non-interactive mode to avoid prompts
export DEBIAN_FRONTEND=noninteractive

apt_get_update() {
    echo "Running apt-get update..."
    apt-get update -y
}

# Install apt-utils to avoid debconf warning, then verify presence of other common developer tools and dependencies
if [ "${PACKAGES_ALREADY_INSTALLED}" != "true" ]; then

    package_list="apt-utils \
        openssh-client \
        gnupg2 \
        dirmngr \
        iproute2 \
        procps \
        lsof \
        htop \
        net-tools \
        psmisc \
        curl \
        tree \
        wget \
        rsync \
        ca-certificates \
        unzip \
        bzip2 \
        zip \
        nano \
        vim-tiny \
        less \
        jq \
        lsb-release \
        apt-transport-https \
        dialog \
        libc6 \
        libgcc1 \
        libkrb5-3 \
        libgssapi-krb5-2 \
        libicu[0-9][0-9] \
        liblttng-ust[0-9] \
        libstdc++6 \
        zlib1g \
        locales \
        sudo \
        ncdu \
        man-db \
        strace \
        manpages \
        manpages-dev \
        init-system-helpers"

    # Needed for adding manpages-posix and manpages-posix-dev which are non-free packages in Debian
    if [ "${ADD_NON_FREE_PACKAGES}" = "true" ]; then
        # Bring in variables from /etc/os-release like VERSION_CODENAME
        . /etc/os-release
        sed -i -E "s/deb http:\/\/(deb|httpredir)\.debian\.org\/debian ${VERSION_CODENAME} main/deb http:\/\/\1\.debian\.org\/debian ${VERSION_CODENAME} main contrib non-free/" /etc/apt/sources.list
        sed -i -E "s/deb-src http:\/\/(deb|httpredir)\.debian\.org\/debian ${VERSION_CODENAME} main/deb http:\/\/\1\.debian\.org\/debian ${VERSION_CODENAME} main contrib non-free/" /etc/apt/sources.list
        sed -i -E "s/deb http:\/\/(deb|httpredir)\.debian\.org\/debian ${VERSION_CODENAME}-updates main/deb http:\/\/\1\.debian\.org\/debian ${VERSION_CODENAME}-updates main contrib non-free/" /etc/apt/sources.list
        sed -i -E "s/deb-src http:\/\/(deb|httpredir)\.debian\.org\/debian ${VERSION_CODENAME}-updates main/deb http:\/\/\1\.debian\.org\/debian ${VERSION_CODENAME}-updates main contrib non-free/" /etc/apt/sources.list
        sed -i "s/deb http:\/\/security\.debian\.org\/debian-security ${VERSION_CODENAME}\/updates main/deb http:\/\/security\.debian\.org\/debian-security ${VERSION_CODENAME}\/updates main contrib non-free/" /etc/apt/sources.list
        sed -i "s/deb-src http:\/\/security\.debian\.org\/debian-security ${VERSION_CODENAME}\/updates main/deb http:\/\/security\.debian\.org\/debian-security ${VERSION_CODENAME}\/updates main contrib non-free/" /etc/apt/sources.list
        sed -i "s/deb http:\/\/deb\.debian\.org\/debian ${VERSION_CODENAME}-backports main/deb http:\/\/deb\.debian\.org\/debian ${VERSION_CODENAME}-backports main contrib non-free/" /etc/apt/sources.list
        sed -i "s/deb-src http:\/\/deb\.debian\.org\/debian ${VERSION_CODENAME}-backports main/deb http:\/\/deb\.debian\.org\/debian ${VERSION_CODENAME}-backports main contrib non-free/" /etc/apt/sources.list
        # Handle bullseye location for security https://www.debian.org/releases/bullseye/amd64/release-notes/ch-information.en.html
        sed -i "s/deb http:\/\/security\.debian\.org\/debian-security ${VERSION_CODENAME}-security main/deb http:\/\/security\.debian\.org\/debian-security ${VERSION_CODENAME}-security main contrib non-free/" /etc/apt/sources.list
        sed -i "s/deb-src http:\/\/security\.debian\.org\/debian-security ${VERSION_CODENAME}-security main/deb http:\/\/security\.debian\.org\/debian-security ${VERSION_CODENAME}-security main contrib non-free/" /etc/apt/sources.list
        echo "Running apt-get update..."
        apt-get update
        package_list="${package_list} manpages-posix manpages-posix-dev"
    else
        apt_get_update
    fi

    # Install libssl1.1 if available
    if [[ ! -z $(apt-cache --names-only search ^libssl1.1$) ]]; then
        package_list="${package_list} libssl1.1"
    fi

    # Install appropriate version of libssl1.0.x if available
    libssl_package=$(dpkg-query -f '${db:Status-Abbrev}\t${binary:Package}\n' -W 'libssl1\.0\.?' 2>&1 || echo '')
    if [ "$(echo "$libssl_package" | grep -o 'libssl1\.0\.[0-9]:' | uniq | sort | wc -l)" -eq 0 ]; then
        if [[ ! -z $(apt-cache --names-only search ^libssl1.0.2$) ]]; then
            # Debian 9
            package_list="${package_list} libssl1.0.2"
        elif [[ ! -z $(apt-cache --names-only search ^libssl1.0.0$) ]]; then
            # Ubuntu 18.04, 16.04, earlier
            package_list="${package_list} libssl1.0.0"
        fi
    fi

    echo "Packages to verify are installed: ${package_list}"
    apt-get -y install --no-install-recommends ${package_list} 2> >(grep -v 'debconf: delaying package configuration, since apt-utils is not installed' >&2)

    # Install git if not already installed (may be more recent than distro version)
    if ! type git >/dev/null 2>&1; then
        apt-get -y install --no-install-recommends git
    fi

    PACKAGES_ALREADY_INSTALLED="true"
fi

# Get to latest versions of all packages
if [ "${UPGRADE_PACKAGES}" = "true" ]; then
    apt_get_update
    apt-get -y upgrade --no-install-recommends
    apt-get autoremove -y
fi

# Ensure at least the en_US.UTF-8 UTF-8 locale is available.
# Common need for both applications and things like the agnoster ZSH theme.
if [ "${LOCALE_ALREADY_SET}" != "true" ] && ! grep -o -E '^\s*en_US.UTF-8\s+UTF-8' /etc/locale.gen >/dev/null; then
    echo "en_US.UTF-8 UTF-8" >>/etc/locale.gen
    locale-gen
    LOCALE_ALREADY_SET="true"
fi

# Create or update a non-root user to match UID/GID.
group_name="${USERNAME}"
if id -u ${USERNAME} >/dev/null 2>&1; then
    # User exists, update if needed
    if [ "${USER_GID}" != "automatic" ] && [ "$USER_GID" != "$(id -g $USERNAME)" ]; then
        group_name="$(id -gn $USERNAME)"
        groupmod --gid $USER_GID ${group_name}
        usermod --gid $USER_GID $USERNAME
    fi
    if [ "${USER_UID}" != "automatic" ] && [ "$USER_UID" != "$(id -u $USERNAME)" ]; then
        usermod --uid $USER_UID $USERNAME
    fi
else
    # Create user
    if [ "${USER_GID}" = "automatic" ]; then
        groupadd $USERNAME
    else
        groupadd --gid $USER_GID $USERNAME
    fi
    if [ "${USER_UID}" = "automatic" ]; then
        useradd -s /bin/bash --gid $USERNAME -m $USERNAME
    else
        useradd -s /bin/bash --uid $USER_UID --gid $USERNAME -m $USERNAME
    fi
fi

# Add sudo support for non-root user
if [ "${USERNAME}" != "root" ] && [ "${EXISTING_NON_ROOT_USER}" != "${USERNAME}" ]; then
    echo $USERNAME ALL=\(root\) NOPASSWD:ALL >/etc/sudoers.d/$USERNAME
    chmod 0440 /etc/sudoers.d/$USERNAME
    EXISTING_NON_ROOT_USER="${USERNAME}"
fi

# ** Shell customization section **
if [ "${USERNAME}" = "root" ]; then
    user_rc_path="/root"
else
    user_rc_path="/home/${USERNAME}"
fi

# Restore user .bashrc defaults from skeleton file if it doesn't exist or is empty
if [ ! -f "${user_rc_path}/.bashrc" ] || [ ! -s "${user_rc_path}/.bashrc" ]; then
    cp /etc/skel/.bashrc "${user_rc_path}/.bashrc"
fi

# Restore user .profile defaults from skeleton file if it doesn't exist or is empty
if [ ! -f "${user_rc_path}/.profile" ] || [ ! -s "${user_rc_path}/.profile" ]; then
    cp /etc/skel/.profile "${user_rc_path}/.profile"
fi

# .bashrc/.zshrc snippet
rc_snippet="$(
    cat <<'EOF'

if [ -z "${USER}" ]; then export USER=$(whoami); fi
if [[ "${PATH}" != *"$HOME/.local/bin"* ]]; then export PATH="${PATH}:$HOME/.local/bin"; fi

# Display optional first run image specific notice if configured and terminal is interactive
if [ -t 1 ] && [[ "${TERM_PROGRAM}" = "vscode" || "${TERM_PROGRAM}" = "codespaces" ]] && [ ! -f "$HOME/.config/vscode-dev-containers/first-run-notice-already-displayed" ]; then
    if [ -f "${DEV_CONTAINERS_DIR}/first-run-notice.txt" ]; then
        cat "${DEV_CONTAINERS_DIR}/first-run-notice.txt"
    elif [ -f "/workspaces/.codespaces/shared/first-run-notice.txt" ]; then
        cat "/workspaces/.codespaces/shared/first-run-notice.txt"
    fi
    mkdir -p "$HOME/.config/vscode-dev-containers"
    # Mark first run notice as displayed after 10s to avoid problems with fast terminal refreshes hiding it
    ((sleep 10s; touch "$HOME/.config/vscode-dev-containers/first-run-notice-already-displayed") &)
fi

# Set the default git editor if not already set
if [ -z "$(git config --get core.editor)" ] && [ -z "${GIT_EDITOR}" ]; then
    if [ "${TERM_PROGRAM}" = "vscode" ]; then
        if [[ -n $(command -v code-insiders) && -z $(command -v code) ]]; then
            export GIT_EDITOR="code-insiders --wait"
        else
            export GIT_EDITOR="code --wait"
        fi
    fi
fi

EOF
)"

# code shim: falls back to code-insiders if code is not available
cat <<'EOF' >/usr/local/bin/code
#!/bin/sh

get_in_path_except_current() {
    which -a "$1" | grep -A1 "$0" | grep -v "$0"
}

code="$(get_in_path_except_current code)"

if [ -n "$code" ]; then
    exec "$code" "$@"
elif [ "$(command -v code-insiders)" ]; then
    exec code-insiders "$@"
else
    echo "code or code-insiders is not installed" >&2
    exit 127
fi
EOF
chmod +x /usr/local/bin/code

# systemctl shim - tells people to use 'service' if systemd is not running
cat <<'EOF' >/usr/local/bin/systemctl
#!/bin/sh
set -e
if [ -d "/run/systemd/system" ]; then
    exec /bin/systemctl "$@"
else
    echo '\n"systemd" is not running in this container due to its overhead.\nUse the "service" command to start services instead. e.g.: \n\nservice --status-all'
fi
EOF
chmod +x /usr/local/bin/systemctl

# Codespaces bash and OMZ themes - partly inspired by https://github.com/ohmyzsh/ohmyzsh/blob/master/themes/robbyrussell.zsh-theme
codespaces_bash="$(
    cat \
        <<'EOF'

# Codespaces bash prompt theme
__bash_prompt() {
    local userpart='`export XIT=$? \
        && [ ! -z "${GITHUB_USER}" ] && echo -n "\[\033[0;32m\]@${GITHUB_USER} " || echo -n "\[\033[0;32m\]\u " \
        && [ "$XIT" -ne "0" ] && echo -n "\[\033[1;31m\]➜" || echo -n "\[\033[0m\]➜"`'
    local gitbranch='`\
        if [ "$(git config --get codespaces-theme.hide-status 2>/dev/null)" != 1 ]; then \
            export BRANCH=$(git symbolic-ref --short HEAD 2>/dev/null || git rev-parse --short HEAD 2>/dev/null); \
            if [ "${BRANCH}" != "" ]; then \
                echo -n "\[\033[0;36m\](\[\033[1;31m\]${BRANCH}" \
                && if git ls-files --error-unmatch -m --directory --no-empty-directory -o --exclude-standard ":/*" > /dev/null 2>&1; then \
                    echo -n " \[\033[1;33m\]✗"; \
                fi \
                && echo -n "\[\033[0;36m\]) "; \
            fi; \
        fi`'
    local lightblue='\[\033[1;34m\]'
    local removecolor='\[\033[0m\]'
    PS1="${userpart} ${lightblue}\w ${gitbranch}${removecolor}\$ "
    unset -f __bash_prompt
}
__bash_prompt

EOF
)"

codespaces_zsh="$(
    cat \
        <<'EOF'
# Codespaces zsh prompt theme
__zsh_prompt() {
    local prompt_username
    if [ ! -z "${GITHUB_USER}" ]; then
        prompt_username="@${GITHUB_USER}"
    else
        prompt_username="%n"
    fi
    PROMPT="%{$fg[green]%}${prompt_username} %(?:%{$reset_color%}➜ :%{$fg_bold[red]%}➜ )" # User/exit code arrow
    PROMPT+='%{$fg_bold[blue]%}%(5~|%-1~/…/%3~|%4~)%{$reset_color%} ' # cwd
    PROMPT+='$([ "$(git config --get codespaces-theme.hide-status 2>/dev/null)" != 1 ] && git_prompt_info)' # Git status
    PROMPT+='%{$fg[white]%}$ %{$reset_color%}'
    unset -f __zsh_prompt
}
ZSH_THEME_GIT_PROMPT_PREFIX="%{$fg_bold[cyan]%}(%{$fg_bold[red]%}"
ZSH_THEME_GIT_PROMPT_SUFFIX="%{$reset_color%} "
ZSH_THEME_GIT_PROMPT_DIRTY=" %{$fg_bold[yellow]%}✗%{$fg_bold[cyan]%})"
ZSH_THEME_GIT_PROMPT_CLEAN="%{$fg_bold[cyan]%})"
__zsh_prompt

EOF
)"

# Add RC snippet and custom bash prompt
if [ "${RC_SNIPPET_ALREADY_ADDED}" != "true" ]; then
    echo "${rc_snippet}" >>/etc/bash.bashrc
    echo "${codespaces_bash}" >>"${user_rc_path}/.bashrc"
    echo 'export PROMPT_DIRTRIM=4' >>"${user_rc_path}/.bashrc"
    if [ "${USERNAME}" != "root" ]; then
        echo "${codespaces_bash}" >>"/root/.bashrc"
        echo 'export PROMPT_DIRTRIM=4' >>"/root/.bashrc"
    fi
    chown ${USERNAME}:${group_name} "${user_rc_path}/.bashrc"
    RC_SNIPPET_ALREADY_ADDED="true"
fi

# Optionally install and configure zsh and Oh My Zsh!
if [ "${INSTALL_ZSH}" = "true" ]; then
    if ! type zsh >/dev/null 2>&1; then
        apt_get_update
        apt-get install -y zsh
    fi
    if [ "${ZSH_ALREADY_INSTALLED}" != "true" ]; then
        echo "${rc_snippet}" >>/etc/zsh/zshrc
        ZSH_ALREADY_INSTALLED="true"
    fi

    # Adapted, simplified inline Oh My Zsh! install steps that add and default to a codespaces theme.
    # See https://github.com/ohmyzsh/ohmyzsh/blob/master/tools/install.sh for official script.
    oh_my_install_dir="${user_rc_path}/.oh-my-zsh"
    if [ ! -d "${oh_my_install_dir}" ] && [ "${INSTALL_OH_MY_ZSH}" = "true" ]; then
        template_path="${oh_my_install_dir}/templates/zshrc.zsh-template"
        user_rc_file="${user_rc_path}/.zshrc"
        umask g-w,o-w
        mkdir -p ${oh_my_install_dir}
        git clone --depth=1 \
            -c core.eol=lf \
            -c core.autocrlf=false \
            -c fsck.zeroPaddedFilemode=ignore \
            -c fetch.fsck.zeroPaddedFilemode=ignore \
            -c receive.fsck.zeroPaddedFilemode=ignore \
            "https://github.com/ohmyzsh/ohmyzsh" "${oh_my_install_dir}" 2>&1
        echo -e "$(cat "${template_path}")\nDISABLE_AUTO_UPDATE=true\nDISABLE_UPDATE_PROMPT=true" >${user_rc_file}
        sed -i -e 's/ZSH_THEME=.*/ZSH_THEME="codespaces"/g' ${user_rc_file}

        mkdir -p ${oh_my_install_dir}/custom/themes
        echo "${codespaces_zsh}" >"${oh_my_install_dir}/custom/themes/codespaces.zsh-theme"
        # Shrink git while still enabling updates
        cd "${oh_my_install_dir}"
        git repack -a -d -f --depth=1 --window=1
        # Copy to non-root user if one is specified
        if [ "${USERNAME}" != "root" ]; then
            cp -rf "${user_rc_file}" "${oh_my_install_dir}" /root
            chown -R ${USERNAME}:${group_name} "${user_rc_path}"
        fi
    fi
fi

# Persist image metadata info; install script if meta.env found in same directory
meta_info_script="$(
    cat <<'EOF'
#!/bin/sh
. /usr/local/etc/vscode-dev-containers/meta.env

# Minimal output
if [ "$1" = "version" ] || [ "$1" = "image-version" ]; then
    echo "${VERSION}"
    exit 0
elif [ "$1" = "release" ]; then
    echo "${GIT_REPOSITORY_RELEASE}"
    exit 0
elif [ "$1" = "content" ] || [ "$1" = "content-url" ] || [ "$1" = "contents" ] || [ "$1" = "contents-url" ]; then
    echo "${CONTENTS_URL}"
    exit 0
fi

# Full output
echo
echo "Development container image information"
echo
if [ ! -z "${VERSION}" ]; then echo "- Image version: ${VERSION}"; fi
if [ ! -z "${DEFINITION_ID}" ]; then echo "- Definition ID: ${DEFINITION_ID}"; fi
if [ ! -z "${VARIANT}" ]; then echo "- Variant: ${VARIANT}"; fi
if [ ! -z "${GIT_REPOSITORY}" ]; then echo "- Source code repository: ${GIT_REPOSITORY}"; fi
if [ ! -z "${GIT_REPOSITORY_RELEASE}" ]; then echo "- Source code release/branch: ${GIT_REPOSITORY_RELEASE}"; fi
if [ ! -z "${BUILD_TIMESTAMP}" ]; then echo "- Timestamp: ${BUILD_TIMESTAMP}"; fi
if [ ! -z "${CONTENTS_URL}" ]; then echo && echo "More info: ${CONTENTS_URL}"; fi
echo
EOF
)"
if [ -f "${DEV_CONTAINERS_DIR}/meta.env" ]; then
    echo "${meta_info_script}" >/usr/local/bin/devcontainer-info
    chmod +x /usr/local/bin/devcontainer-info
fi

if [ ! -d "${DEV_CONTAINERS_DIR}" ]; then
    mkdir -p "$(dirname "${MARKER_FILE}")"
fi

# Write marker file
echo -e "\
PACKAGES_ALREADY_INSTALLED=${PACKAGES_ALREADY_INSTALLED}\n\
LOCALE_ALREADY_SET=${LOCALE_ALREADY_SET}\n\
EXISTING_NON_ROOT_USER=${EXISTING_NON_ROOT_USER}\n\
RC_SNIPPET_ALREADY_ADDED=${RC_SNIPPET_ALREADY_ADDED}\n\
ZSH_ALREADY_INSTALLED=${ZSH_ALREADY_INSTALLED}" >"${MARKER_FILE}"

echo "Done!"
.gitignore (vendored, 179 changed lines)
@@ -1,177 +1,8 @@
# Personal ignores
wandb/
wandb-local/
data/
dataset*/
*.parquet
.venv/
lightning_logs/

checkpoints/
*.pth
*.onnx
*.ckpt

images/
*.png
*.jpg

# https://github.com/github/gitignore/blob/main/Python.gitignore
# Basic .gitignore for a python repo.

# Byte-compiled / optimized / DLL files
__pycache__/
*.py[cod]
*$py.class

# C extensions
*.so

# Distribution / packaging
.Python
build/
develop-eggs/
dist/
downloads/
eggs/
.eggs/
lib/
lib64/
parts/
sdist/
var/
wheels/
share/python-wheels/
*.egg-info/
.installed.cfg
*.egg
MANIFEST

# PyInstaller
# Usually these files are written by a python script from a template
# before PyInstaller builds the exe, so as to inject date/other infos into it.
*.manifest
*.spec

# Installer logs
pip-log.txt
pip-delete-this-directory.txt

# Unit test / coverage reports
htmlcov/
.tox/
.nox/
.coverage
.coverage.*
.cache
nosetests.xml
coverage.xml
*.cover
*.py,cover
.hypothesis/
.pytest_cache/
cover/

# Translations
*.mo
*.pot

# Django stuff:
*.log
local_settings.py
db.sqlite3
db.sqlite3-journal

# Flask stuff:
instance/
.webassets-cache

# Scrapy stuff:
.scrapy

# Sphinx documentation
docs/_build/

# PyBuilder
.pybuilder/
target/

# Jupyter Notebook
.ipynb_checkpoints

# IPython
profile_default/
ipython_config.py

# pyenv
# For a library or package, you might want to ignore these files since the code is
# intended to run in multiple environments; otherwise, check them in:
# .python-version

# pipenv
# According to pypa/pipenv#598, it is recommended to include Pipfile.lock in version control.
# However, in case of collaboration, if having platform-specific dependencies or dependencies
# having no cross-platform support, pipenv may install dependencies that don't work, or not
# install all needed dependencies.
#Pipfile.lock

# poetry
# Similar to Pipfile.lock, it is generally recommended to include poetry.lock in version control.
# This is especially recommended for binary packages to ensure reproducibility, and is more
# commonly ignored for libraries.
# https://python-poetry.org/docs/basic-usage/#commit-your-poetrylock-file-to-version-control
#poetry.lock

# pdm
# Similar to Pipfile.lock, it is generally recommended to include pdm.lock in version control.
#pdm.lock
# pdm stores project-wide configurations in .pdm.toml, but it is recommended to not include it
# in version control.
# https://pdm.fming.dev/#use-with-ide
.pdm.toml

# PEP 582; used by e.g. github.com/David-OConnor/pyflow and github.com/pdm-project/pdm
__pypackages__/

# Celery stuff
celerybeat-schedule
celerybeat.pid

# SageMath parsed files
*.sage.py

# Environments
.env
.venv
env/
venv/
ENV/
env.bak/
venv.bak/

# Spyder project settings
.spyderproject
.spyproject

# Rope project settings
.ropeproject

# mkdocs documentation
/site

# mypy
.mypy_cache/
.dmypy.json
dmypy.json

# Pyre type checker
.pyre/

# pytype static type analyzer
.pytype/

# Cython debug symbols
cython_debug/

# PyCharm
# JetBrains specific template is maintained in a separate JetBrains.gitignore that can
# be found at https://github.com/github/gitignore/blob/main/Global/JetBrains.gitignore
# and can be added to the global gitignore or merged into this file. For a more nuclear
# option (not recommended) you can uncomment the following to ignore the entire idea folder.
#.idea/
*.jpg
@@ -1,46 +0,0 @@
repos:
  - repo: https://github.com/asottile/pyupgrade
    rev: "v2.37.3"
    hooks:
      - id: pyupgrade

  - repo: https://github.com/python-poetry/poetry
    rev: "1.2.0rc1"
    hooks:
      - id: poetry-check
      - id: poetry-lock

  - repo: https://github.com/pre-commit/pre-commit-hooks
    rev: "v4.3.0"
    hooks:
      # - id: check-added-large-files
      - id: check-executables-have-shebangs
      - id: check-merge-conflict
      - id: check-symlinks
      # - id: check-json
      - id: check-toml
      - id: check-yaml
      - id: debug-statements
      - id: destroyed-symlinks
      - id: detect-private-key
      - id: end-of-file-fixer
      - id: fix-byte-order-marker
      - id: mixed-line-ending
      - id: trailing-whitespace

  - repo: https://github.com/pre-commit/mirrors-mypy
    rev: "v0.971"
    hooks:
      - id: mypy

  - repo: https://github.com/pycqa/isort
    rev: "5.10.1"
    hooks:
      - id: isort
        name: isort (python)

  - repo: https://github.com/psf/black
    rev: "22.8.0"
    hooks:
      - id: black
        language_version: python
.vscode/extensions.json (vendored, 11 lines)
@@ -1,11 +0,0 @@
{
    "recommendations": [
        "ms-vscode-remote.remote-containers",
        "ms-azuretools.vscode-docker",
        "editorconfig.editorconfig",
        "njpwerner.autodocstring",
        "ms-python.python",
        "ms-toolsai.jupyter",
        "eamodio.gitlens"
    ]
}
.vscode/launch.json (vendored, 61 changed lines)
@@ -1,29 +1,36 @@
{
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Train",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/src/train.py",
            "console": "integratedTerminal",
            "justMyCode": false,
        },
        {
            "name": "Predict",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/src/predict.py",
            "console": "integratedTerminal",
            "justMyCode": false,
            "args": [
                "--input",
                "images/input.png",
                "--output",
                "images/output.png",
                "--model",
                "checkpoints/model.onnx"
            ]
        }
    ]
    // Use IntelliSense to learn about possible attributes.
    // Hover to view descriptions of existing attributes.
    // For more information, visit: https://go.microsoft.com/fwlink/?linkid=830387
    "version": "0.2.0",
    "configurations": [
        {
            "name": "Python: Current File",
            "type": "python",
            "request": "launch",
            "program": "${workspaceFolder}/src/main.py",
            // "program": "${workspaceFolder}/src/spheres.py",
            // "program": "${workspaceFolder}/src/datamodule.py",
            "args": [
                // "fit",
                "predict",
                // "--ckpt_path",
                // "${workspaceFolder}/lightning_logs/version_264/checkpoints/epoch=9-step=1000.ckpt",
                "--data.num_workers",
                "0",
                "--trainer.benchmark",
                "false",
                "--trainer.num_sanity_val_steps",
                "0",
                "--data.persistent_workers",
                "false",
                "--data.batch_size",
                "1",
                "--trainer.val_check_interval",
                "1"
            ],
            "console": "integratedTerminal",
            "justMyCode": false
        }
    ]
}
.vscode/settings.json (vendored, 50 changed lines)
@@ -1,27 +1,27 @@
{
    "python.defaultInterpreterPath": ".venv/bin/python",
    "python.formatting.provider": "black",
    "editor.formatOnSave": true,
    "python.linting.enabled": true,
    "python.linting.lintOnSave": true,
    "python.linting.flake8Enabled": true,
    "python.linting.mypyEnabled": true,
    "python.linting.banditEnabled": true,
    "jupyter.debugJustMyCode": false,
    "[python]": {
        "editor.codeActionsOnSave": {
            "source.organizeImports": true
        }
    },
    "files.insertFinalNewline": true,
    "files.exclude": {
        "**/.git": true,
        "**/.svn": true,
        "**/.hg": true,
        "**/CVS": true,
        "**/.DS_Store": true,
        "**/Thumbs.db": true,
        "**/__pycache__": true,
        "**/.mypy_cache": true,
    }
    // "python.defaultInterpreterPath": ".venv/bin/python",
    "python.analysis.typeCheckingMode": "off",
    "python.formatting.provider": "black",
    "editor.formatOnSave": true,
    "python.linting.enabled": true,
    "python.linting.lintOnSave": true,
    "python.linting.flake8Enabled": true,
    "python.linting.mypyEnabled": true,
    "python.linting.banditEnabled": true,
    "python.languageServer": "Pylance",
    "[python]": {
        "editor.codeActionsOnSave": {
            "source.organizeImports": true
        }
    },
    "files.exclude": {
        "**/.git": true,
        "**/.svn": true,
        "**/.hg": true,
        "**/CVS": true,
        "**/.DS_Store": true,
        "**/Thumbs.db": true,
        "**/__pycache__": true,
        "**/.mypy_cache": true,
    },
}
poetry.lock (generated, 8089 lines): file diff suppressed because it is too large
@@ -1,52 +1,55 @@
[tool.poetry]
authors = ["Laurent Fainsin <laurent@fainsin.bzh>"]
description = "Simple neural network to detect calibration spheres in images."
license = "MIT"
name = "sphereDetect"
readme = "README.md"
version = "2.0.0"
authors = ["Laurent Fainsin <laurentfainsin@protonmail.com>"]
description = ""
name = "label-studio"
version = "1.0.0"

[tool.poetry.dependencies]
albumentations = "^1.2.1"
lightning-bolts = "^0.5.0"
numpy = "^1.23.2"
pycocotools = "^2.0.4"
python = ">=3.8,<3.11"
pytorch-lightning = "^1.7.4"
rich = "^12.5.1"
torch = "^1.12.1"
torchmetrics = "^0.9.3"
torchvision = "^0.13.1"
wandb = "^0.13.2"
datasets = "^2.9.0"
fastapi = "0.86.0"
jsonargparse = {extras = ["signatures"], version = "^4.20.0"}
lightning = "1.9.1"
matplotlib = "^3.7.0"
numpy = "^1.24.2"
opencv-python = "^4.7.0.72"
opencv-python-headless = "^4.7.0.72"
python = ">=3.8,<3.12"
rich = "^13.3.1"
scipy = "^1.10.0"
timm = "^0.6.12"
torch = "^1.13.1"
transformers = "^4.26.1"

[tool.poetry.group.notebooks]
optional = true

[tool.poetry.group.notebooks.dependencies]
ipykernel = "^6.15.3"
matplotlib = "^3.5.3"
onnx = "^1.12.0"
onnxruntime = "^1.12.1"
onnxruntime-gpu = "^1.12.1"
ipykernel = "^6.20.2"
ipywidgets = "^8.0.4"
jupyter = "^1.0.0"
matplotlib = "^3.6.3"

[tool.poetry.group.dev.dependencies]
Flake8-pyproject = "^1.1.0"
bandit = "^1.7.4"
black = {extras = ["jupyter"], version = "^22.8.0"}
black = "^22.8.0"
flake8 = "^5.0.4"
flake8-docstrings = "^1.6.0"
isort = "^5.10.1"
mypy = "^0.971"
pre-commit = "^2.20.0"
tensorboard = "^2.12.0"
torchtyping = "^0.1.4"
torch-tb-profiler = "^0.4.1"

[build-system]
build-backend = "poetry.core.masonry.api"
requires = ["poetry-core>=1.0.0"]
requires = ["poetry-core"]

[tool.flake8]
# rules ignored
extend-ignore = ["W503", "D401", "D403"]
per-file-ignores = ["__init__.py:F401", "__init__.py:D104"]
extend-ignore = ["W503", "D401", "D100", "D104"]
per-file-ignores = ["__init__.py:F401"]
# black
ignore = "E203"
max-line-length = 120
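The dependency swap above is the heart of this commit: torchvision, wandb and pytorch-lightning give way to transformers, datasets and the unified lightning package. A minimal sketch (assuming only that the pinned packages are installed; the checkpoint name is the one src/module.py below actually uses) of loading the DETR stack these pins target:

    from transformers import DetrFeatureExtractor, DetrForObjectDetection

    # same pretrained checkpoint referenced by src/datamodule.py and src/module.py
    feature_extractor = DetrFeatureExtractor.from_pretrained("facebook/detr-resnet-50")
    model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")
    print(model.config.num_queries)  # 100, the same default as the module's num_queries argument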
src/comp.ipynb (311 lines): file diff suppressed because one or more lines are too long
@@ -1 +0,0 @@
from .dataloader import Spheres
@@ -1,100 +0,0 @@
"""PyTorch Lightning DataModules."""

import albumentations as A
import pytorch_lightning as pl
import wandb
from albumentations.pytorch import ToTensorV2
from torch.utils.data import DataLoader

from .dataset import LabeledDataset, RealDataset


def collate_fn(batch):
    return tuple(zip(*batch))


class Spheres(pl.LightningDataModule):
    """PyTorch Lightning DataModule, encapsulating common PyTorch functions."""

    def train_dataloader(self) -> DataLoader:
        """PyTorch training Dataloader.

        Returns:
            DataLoader: the training dataloader
        """
        transforms = A.Compose(
            [
                # A.Flip(),
                # A.ColorJitter(),
                # A.ToGray(p=0.01),
                # A.GaussianBlur(),
                # A.MotionBlur(),
                # A.ISONoise(),
                # A.ImageCompression(),
                A.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                    max_pixel_value=255,
                ),  # [0, 255] -> coco (?) normalized
                ToTensorV2(),  # HWC -> CHW
            ],
            bbox_params=A.BboxParams(
                format="pascal_voc",
                min_area=0.0,
                min_visibility=0.0,
                label_fields=["labels"],
            ),
        )

        # dataset = LabeledDataset(image_dir="/dev/shm/TRAIN/", transforms=transforms)
        dataset = LabeledDataset(image_dir=wandb.config.DIR_TRAIN_IMG, transforms=transforms)
        # dataset = Subset(dataset, range(6 * 200))  # subset for debugging purpose
        # dataset = Subset(dataset, [0] * 320)  # overfit test

        return DataLoader(
            dataset,
            shuffle=True,
            persistent_workers=True,
            prefetch_factor=wandb.config.PREFETCH_FACTOR,
            batch_size=wandb.config.TRAIN_BATCH_SIZE,
            pin_memory=wandb.config.PIN_MEMORY,
            num_workers=wandb.config.WORKERS,
            collate_fn=collate_fn,
        )

    def val_dataloader(self) -> DataLoader:
        """PyTorch validation Dataloader.

        Returns:
            DataLoader: the validation dataloader
        """
        transforms = A.Compose(
            [
                A.Normalize(
                    mean=[0.485, 0.456, 0.406],
                    std=[0.229, 0.224, 0.225],
                    max_pixel_value=255,
                ),  # [0, 255] -> [0.0, 1.0] normalized
                ToTensorV2(),  # HWC -> CHW
            ],
            bbox_params=A.BboxParams(
                format="pascal_voc",
                min_area=0.0,
                min_visibility=0.0,
                label_fields=["labels"],
            ),
        )

        # dataset = RealDataset(root="/dev/shm/TEST/", transforms=transforms)
        dataset = RealDataset(root=wandb.config.DIR_VALID_IMG, transforms=transforms)

        return DataLoader(
            dataset,
            shuffle=False,
            persistent_workers=True,
            prefetch_factor=wandb.config.PREFETCH_FACTOR,
            batch_size=wandb.config.VALID_BATCH_SIZE,
            pin_memory=wandb.config.PIN_MEMORY,
            num_workers=wandb.config.WORKERS,
            collate_fn=collate_fn,
        )
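The module-level collate_fn above is the usual trick for detection datasets: targets vary in length per image, so the default collate (which stacks everything into tensors) cannot be used, and tuple(zip(*batch)) simply transposes a list of (image, target) pairs into a tuple of images and a tuple of targets. A toy illustration with hypothetical values:

    batch = [
        ("image_0", {"boxes": [[0, 0, 10, 10]]}),
        ("image_1", {"boxes": [[5, 5, 20, 20], [1, 1, 4, 4]]}),
    ]
    images, targets = tuple(zip(*batch))
    print(images)   # ('image_0', 'image_1')
    print(targets)  # ({'boxes': [[0, 0, 10, 10]]}, {'boxes': [[5, 5, 20, 20], [1, 1, 4, 4]]})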
@@ -1,215 +0,0 @@
"""PyTorch Datasets."""

import os
from pathlib import Path

import albumentations as A
import numpy as np
import torch
from PIL import Image
from torch.utils.data import Dataset


class SyntheticDataset(Dataset):
    def __init__(self, image_dir: str, transform: A.Compose) -> None:
        self.images = list(Path(image_dir).glob("**/*.jpg"))
        self.transform = transform

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, index: int):
        # open and convert image
        image = np.ascontiguousarray(
            Image.open(
                self.images[index],
            ).convert("RGB"),
            dtype=np.uint8,
        )

        # create empty mask of same size
        mask = np.zeros(
            (*image.shape[:2], 4),
            dtype=np.uint8,
        )

        # augment image and mask
        augmentations = self.transform(image=image, mask=mask)
        image = augmentations["image"]
        mask = augmentations["mask"]

        return image, mask


class RealDataset(Dataset):
    def __init__(self, root, transforms=None) -> None:
        self.root = root
        self.transforms = transforms

        # load all image files, sorting them to ensure that they are aligned
        self.imgs = list(sorted(os.listdir(os.path.join(root, "images"))))
        self.masks = list(sorted(os.listdir(os.path.join(root, "masks"))))

        self.res = A.LongestMaxSize(max_size=1024)

    def __len__(self) -> int:
        return len(self.imgs)

    def __getitem__(self, idx: int):
        # create paths from ids
        image_path = os.path.join(self.root, "images", self.imgs[idx])
        mask_path = os.path.join(self.root, "masks", self.masks[idx])

        # load image and mask
        image = Image.open(image_path).convert("RGB")
        mask = Image.open(mask_path).convert("L")

        # convert to numpy arrays
        image = np.ascontiguousarray(image)
        mask = np.ascontiguousarray(mask)

        # resize images, TODO: remove ?
        aug = self.res(image=image, mask=mask)
        image = aug["image"]
        mask = aug["mask"]

        # get ids from mask
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[1:]  # first id is the background, so remove it

        # split the color-encoded mask into a set of binary masks
        masks = mask == obj_ids[:, None, None]
        masks = masks.astype(np.uint8)  # cast to uint8 for albumentations

        # create bboxes from masks (pascal format)
        num_objs = len(obj_ids)
        bboxes = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            bboxes.append([xmin, ymin, xmax, ymax])

        # convert arrays for albumentations
        bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
        labels = torch.ones((num_objs,), dtype=torch.int64)  # assume there is only one class (id=1)
        masks = list(np.asarray(masks))

        if self.transforms is not None:
            # arrange transform data
            data = {
                "image": image,
                "labels": labels,
                "bboxes": bboxes,
                "masks": masks,
            }
            # apply transform
            augmented = self.transforms(**data)
            # get augmented data
            image = augmented["image"]
            bboxes = augmented["bboxes"]
            labels = augmented["labels"]
            masks = augmented["masks"]

        bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
        labels = torch.as_tensor(labels, dtype=torch.int64)  # int64 required by torchvision maskrcnn
        masks = torch.stack(masks)  # stack masks, required by torchvision maskrcnn

        area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
        image_id = torch.tensor([idx])
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)  # assume all instances are not crowd

        target = {
            "boxes": bboxes,
            "labels": labels,
            "masks": masks,
            "area": area,
            "image_id": image_id,
            "iscrowd": iscrowd,
        }

        return image, target


class LabeledDataset(Dataset):
    def __init__(self, image_dir, transforms) -> None:
        self.images = list(Path(image_dir).glob("**/*.jpg"))
        self.transforms = transforms

    def __len__(self) -> int:
        return len(self.images)

    def __getitem__(self, idx: int):
        # open and convert image
        image = np.ascontiguousarray(
            Image.open(self.images[idx]).convert("RGB"),
        )

        # open and convert mask
        mask_path = self.images[idx].parent.joinpath("MASK.PNG")
        mask = np.ascontiguousarray(
            Image.open(mask_path).convert("L"),
        )

        # get ids from mask
        obj_ids = np.unique(mask)
        obj_ids = obj_ids[1:]  # first id is the background, so remove it

        # split the color-encoded mask into a set of binary masks
        masks = mask == obj_ids[:, None, None]
        masks = masks.astype(np.uint8)  # cast to uint8 for albumentations

        # create bboxes from masks (pascal format)
        num_objs = len(obj_ids)
        bboxes = []
        labels = []
        for i in range(num_objs):
            pos = np.where(masks[i])
            xmin = np.min(pos[1])
            xmax = np.max(pos[1])
            ymin = np.min(pos[0])
            ymax = np.max(pos[0])
            bboxes.append([xmin, ymin, xmax, ymax])
            labels.append(2 if mask[(ymax + ymin) // 2, (xmax + xmin) // 2] > 127 else 1)

        # convert arrays for albumentations
        bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
        labels = torch.as_tensor(labels, dtype=torch.int64)
        masks = list(np.asarray(masks))

        if self.transforms is not None:
            # arrange transform data
            data = {
                "image": image,
                "labels": labels,
                "bboxes": bboxes,
                "masks": masks,
            }
            # apply transform
            augmented = self.transforms(**data)
            # get augmented data
            image = augmented["image"]
            bboxes = augmented["bboxes"]
            labels = augmented["labels"]
            masks = augmented["masks"]

        bboxes = torch.as_tensor(bboxes, dtype=torch.int64)
        labels = torch.as_tensor(labels, dtype=torch.int64)  # int64 required by torchvision maskrcnn
        masks = torch.stack(masks)  # stack masks, required by torchvision maskrcnn

        area = (bboxes[:, 3] - bboxes[:, 1]) * (bboxes[:, 2] - bboxes[:, 0])
        image_id = torch.tensor([idx])
        iscrowd = torch.zeros((num_objs,), dtype=torch.int64)  # assume all instances are not crowd

        target = {
            "boxes": bboxes,
            "labels": labels,
            "masks": masks,
            "area": area,
            "image_id": image_id,
            "iscrowd": iscrowd,
        }

        return image, target
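Both RealDataset and LabeledDataset derive pascal_voc boxes the same way: np.where on a binary mask yields (rows, cols), whose minima and maxima become (xmin, ymin, xmax, ymax). A toy numpy check of that recipe (values hypothetical):

    import numpy as np

    mask = np.zeros((8, 8), dtype=np.uint8)
    mask[2:5, 3:7] = 1  # a 3x4 blob
    pos = np.where(mask)
    xmin, xmax = np.min(pos[1]), np.max(pos[1])
    ymin, ymax = np.min(pos[0]), np.max(pos[0])
    print([int(xmin), int(ymin), int(xmax), int(ymax)])  # [3, 2, 6, 4]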
src/datamodule.py (new file, 277 lines)
@@ -0,0 +1,277 @@
import datasets
import torch
from lightning.pytorch import LightningDataModule
from lightning.pytorch.trainer.supporters import CombinedLoader
from torch.utils.data import DataLoader
from torchvision.datasets import ImageFolder
from torchvision.transforms import AugMix
from transformers import DetrFeatureExtractor


class DETRDataModule(LightningDataModule):
    """PyTorch Lightning data module for DETR."""

    def __init__(
        self,
        num_workers: int = 8,
        batch_size: int = 6,
        prefetch_factor: int = 2,
        model_name: str = "facebook/detr-resnet-50",
        persistent_workers: bool = True,
    ):
        """Constructor.

        Args:
            num_workers (int, optional): Number of workers.
            batch_size (int, optional): Batch size.
            prefetch_factor (int, optional): Prefetch factor.
            model_name (str, optional): Model name.
            persistent_workers (bool, optional): Whether dataloader workers persist across epochs.
        """
        super().__init__()

        # save params
        self.num_workers = num_workers
        self.batch_size = batch_size
        self.prefetch_factor = prefetch_factor
        self.persistent_workers = persistent_workers

        # get feature extractor
        self.feature_extractor = DetrFeatureExtractor.from_pretrained(model_name)

    def prepare_data(self):
        """Download data and prepare for training."""
        # load datasets
        self.illumination = datasets.load_dataset("src/spheres_illumination.py", split="train")
        self.render = datasets.load_dataset("src/spheres_synth.py", split="train")
        self.real = datasets.load_dataset("src/spheres.py", split="train")

        # split datasets
        self.illumination = self.illumination.train_test_split(test_size=0.01)
        self.render = self.render.train_test_split(test_size=0.01)
        self.real = self.real.train_test_split(test_size=0.1)

        # print some info
        print(f"illumination: {self.illumination}")
        print(f"render: {self.render}")
        print(f"real: {self.real}")

        # other datasets
        self.test_ds = datasets.load_dataset("src/spheres_illumination.py", split="test")
        # self.predict_ds = datasets.load_dataset("src/spheres.py", split="train").shuffle().select(range(16))
        self.predict_ds = datasets.load_dataset("src/spheres_predict.py", split="train")

        # define AugMix transform
        self.mix = AugMix()

        # useful mappings
        self.labels = self.real["test"].features["objects"][0]["category_id"].names
        self.id2label = {k: v for k, v in enumerate(self.labels)}
        self.label2id = {v: k for k, v in enumerate(self.labels)}

    def train_transform(self, batch):
        """Training transform.

        Args:
            batch (dict): Batch precollated by HuggingFace datasets.
                Structure is similar to the following:
                {
                    "image": list[PIL.Image],
                    "image_id": list[int],
                    "objects": [
                        {
                            "bbox": list[float, 4],
                            "category_id": int,
                        }
                    ]
                }

        Returns:
            dict: Augmented and processed batch.
                Structure is similar to the following:
                {
                    "pixel_values": TensorType["batch", "channel", "width", "height"],
                    "pixel_mask": TensorType["batch", 1200, 1200],
                    "labels": List[Dict[str, TensorType["batch", "num_boxes", "num_labels"]]],
                }
        """
        # extract images, ids and objects from batch
        images = batch["image"]
        ids = batch["image_id"]
        objects = batch["objects"]

        # apply AugMix transform
        images_mixed = [self.mix(image) for image in images]

        # build targets for feature extractor
        targets = [
            {
                "image_id": id,
                "annotations": object,
            }
            for id, object in zip(ids, objects)
        ]

        # process images and targets with feature extractor for DETR
        processed = self.feature_extractor(
            images=images_mixed,
            annotations=targets,
            return_tensors="pt",
        )

        return processed

    def val_transform(self, batch):
        """Validation transform.

        Just like the training transform, but without AugMix.
        """
        # extract images, ids and objects from batch
        images = batch["image"]
        ids = batch["image_id"]
        objects = batch["objects"]

        # build targets for feature extractor
        targets = [
            {
                "image_id": id,
                "annotations": object,
            }
            for id, object in zip(ids, objects)
        ]

        processed = self.feature_extractor(
            images=images,
            annotations=targets,
            return_tensors="pt",
        )

        return processed

    def predict_transform(self, batch):
        """Prediction transform.

        Just like val_transform, but keeps the original images.
        """
        processed = self.val_transform(batch)

        # add images to dict
        processed["images"] = batch["image"]

        return processed

    def collate_fn(self, examples):
        """Collate function.

        Convert list of dicts to dict of Tensors.
        """
        return {
            "pixel_values": torch.stack([data["pixel_values"] for data in examples]),
            "pixel_mask": torch.stack([data["pixel_mask"] for data in examples]),
            "labels": [data["labels"] for data in examples],
        }

    def collate_fn_predict(self, examples):
        """Collate function for prediction.

        Convert list of dicts to dict of Tensors, keeping the original images.
        """
        return {
            "pixel_values": torch.stack([data["pixel_values"] for data in examples]),
            "pixel_mask": torch.stack([data["pixel_mask"] for data in examples]),
            "labels": [data["labels"] for data in examples],
            "images": [data["images"] for data in examples],
        }

    def train_dataloader(self):
        """Training dataloader."""
        loaders = {
            "illumination": DataLoader(
                self.illumination["train"].with_transform(self.val_transform),
                shuffle=True,
                pin_memory=True,
                persistent_workers=self.persistent_workers,
                collate_fn=self.collate_fn,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
            ),
            "render": DataLoader(
                self.render["train"].with_transform(self.val_transform),
                shuffle=True,
                pin_memory=True,
                persistent_workers=self.persistent_workers,
                collate_fn=self.collate_fn,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
            ),
            "real": DataLoader(
                self.real["train"].with_transform(self.val_transform),
                shuffle=True,
                pin_memory=True,
                persistent_workers=self.persistent_workers,
                collate_fn=self.collate_fn,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
            ),
        }

        return CombinedLoader(loaders, mode="max_size_cycle")

    def val_dataloader(self):
        """Validation dataloader."""
        loaders = {
            "illumination": DataLoader(
                self.illumination["test"].with_transform(self.val_transform),
                pin_memory=True,
                persistent_workers=self.persistent_workers,
                collate_fn=self.collate_fn,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
            ),
            "render": DataLoader(
                self.render["test"].with_transform(self.val_transform),
                pin_memory=True,
                persistent_workers=self.persistent_workers,
                collate_fn=self.collate_fn,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
            ),
            "real": DataLoader(
                self.real["test"].with_transform(self.val_transform),
                pin_memory=True,
                persistent_workers=self.persistent_workers,
                collate_fn=self.collate_fn,
                batch_size=self.batch_size,
                num_workers=self.num_workers,
                prefetch_factor=self.prefetch_factor,
            ),
        }

        return CombinedLoader(loaders, mode="max_size_cycle")

    def predict_dataloader(self):
        """Prediction dataloader."""
        return DataLoader(
            self.predict_ds.with_transform(self.predict_transform),
            pin_memory=True,
            persistent_workers=self.persistent_workers,
            collate_fn=self.collate_fn_predict,
            batch_size=self.batch_size,
            num_workers=self.num_workers,
            prefetch_factor=self.prefetch_factor,
        )


if __name__ == "__main__":
    # load data
    dm = DETRDataModule()
    dm.prepare_data()
    ds = dm.train_dataloader()

    for batch in ds:
        print(batch)
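With mode="max_size_cycle", CombinedLoader draws one batch from every loader per step and restarts the shorter ones until the longest is exhausted, yielding a dict keyed by loader name, which is exactly the {dataloader_name: batch} mapping that common_step in src/module.py iterates over. A minimal runnable sketch with toy loaders (names hypothetical):

    from torch.utils.data import DataLoader
    from lightning.pytorch.trainer.supporters import CombinedLoader  # same import as above

    loaders = {
        "real": DataLoader(list(range(6)), batch_size=2),    # 3 batches
        "render": DataLoader(list(range(2)), batch_size=2),  # 1 batch, cycled
    }
    for step, batch in enumerate(CombinedLoader(loaders, mode="max_size_cycle")):
        print(step, batch)  # e.g. 0 {'real': tensor([0, 1]), 'render': tensor([0, 1])}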
File diff suppressed because one or more lines are too long

src/main.py (new file, 44 lines)
@@ -0,0 +1,44 @@
from lightning.pytorch.callbacks import (
    ModelCheckpoint,
    RichModelSummary,
    RichProgressBar,
)
from lightning.pytorch.cli import LightningCLI

from datamodule import DETRDataModule
from module import DETR


class MyLightningCLI(LightningCLI):
    """Custom Lightning CLI to define default arguments."""

    def add_arguments_to_parser(self, parser):
        """Add arguments to parser."""
        parser.set_defaults(
            {
                "trainer.multiple_trainloader_mode": "max_size_cycle",
                "trainer.max_steps": 5000,
                "trainer.max_epochs": 1,
                "trainer.accelerator": "gpu",
                "trainer.devices": "[1]",
                "trainer.strategy": "dp",
                "trainer.log_every_n_steps": 25,
                "trainer.val_check_interval": 200,
                "trainer.num_sanity_val_steps": 10,
                "trainer.benchmark": True,
                "trainer.callbacks": [
                    RichProgressBar(),
                    RichModelSummary(max_depth=2),
                    ModelCheckpoint(mode="min", monitor="val_loss_real"),
                    ModelCheckpoint(save_on_train_epoch_end=True),
                ],
            }
        )


if __name__ == "__main__":
    cli = MyLightningCLI(
        model_class=DETR,
        datamodule_class=DETRDataModule,
        seed_everything_default=69420,
    )
191
src/module.py
Normal file
191
src/module.py
Normal file
|
@ -0,0 +1,191 @@
|
|||
import torch
from lightning.pytorch import LightningModule
from PIL import ImageDraw
from transformers import (
    AutoImageProcessor,
    DetrForObjectDetection,
    get_cosine_with_hard_restarts_schedule_with_warmup,
)


class DETR(LightningModule):
    """PyTorch Lightning module for DETR."""

    def __init__(
        self,
        lr: float = 1e-4,
        lr_backbone: float = 1e-5,
        weight_decay: float = 1e-4,
        num_queries: int = 100,
        warmup_steps: int = 0,
        num_labels: int = 3,
        prediction_threshold: float = 0.9,
    ):
        """Constructor.

        Args:
            lr (float, optional): Learning rate.
            lr_backbone (float, optional): Learning rate for the backbone.
            weight_decay (float, optional): Weight decay.
            num_queries (int, optional): Number of object queries.
            warmup_steps (int, optional): Number of warmup steps.
            num_labels (int, optional): Number of labels.
            prediction_threshold (float, optional): Prediction threshold.
        """
        super().__init__()

        # replace COCO classification head with custom head
        self.net = DetrForObjectDetection.from_pretrained(
            "facebook/detr-resnet-50",
            ignore_mismatched_sizes=True,
            num_queries=num_queries,
            num_labels=num_labels,
        )

        # cf https://github.com/PyTorchLightning/pytorch-lightning/pull/1896
        self.lr = lr
        self.lr_backbone = lr_backbone
        self.weight_decay = weight_decay
        self.warmup_steps = warmup_steps
        self.prediction_threshold = prediction_threshold

        self.save_hyperparameters()

    def forward(self, pixel_values, pixel_mask, **kwargs):
        """Forward pass."""
        return self.net(
            pixel_values=pixel_values,
            pixel_mask=pixel_mask,
            **kwargs,
        )

    def common_step(self, batches, batch_idx):
        """Common step for training and validation.

        Args:
            batches (dict): One batch per dataloader (after collate_fn).
                Each batch is structured similarly to the following:
                {
                    "pixel_values": TensorType["batch", "channel", "width", "height"],
                    "pixel_mask": TensorType["batch", 1200, 1200],
                    "labels": List[Dict[str, TensorType["batch", "num_boxes", "num_labels"]]],  # TODO: check this type
                }

            batch_idx (int): Batch index.

        Returns:
            dict: Loss and loss dict, keyed by dataloader name.
        """
        # initialize outputs
        outputs = {k: {"loss": None, "loss_dict": None} for k in batches.keys()}

        # for each dataloader
        for dataloader_name, batch in batches.items():
            # extract pixel_values, pixel_mask and labels from batch
            pixel_values = batch["pixel_values"]
            pixel_mask = batch["pixel_mask"]
            labels = [{k: v.to(self.device) for k, v in t.items()} for t in batch["labels"]]

            # forward pass
            model_output = self(pixel_values=pixel_values, pixel_mask=pixel_mask, labels=labels)

            # get loss
            outputs[dataloader_name] = {
                "loss": model_output.loss,
                "loss_dict": model_output.loss_dict,
            }

        return outputs

    def training_step(self, batch, batch_idx):
        """Training step."""
        outputs = self.common_step(batch, batch_idx)

        # log metrics for each training_step and the average across the epoch
        loss = 0
        for dataloader_name, output in outputs.items():
            loss += output["loss"]
            self.log(f"train_loss_{dataloader_name}", output["loss"])
            for k, v in output["loss_dict"].items():
                self.log(f"train_loss_{k}_{dataloader_name}", v.item())

        self.log("lr", self.optimizers().param_groups[0]["lr"])
        self.log("lr_backbone", self.optimizers().param_groups[1]["lr"])

        return loss

    def validation_step(self, batch, batch_idx, dataloader_idx=None):
        """Validation step."""
        outputs = self.common_step(batch, batch_idx)

        # log metrics for each validation_step and the average across the epoch
        loss = 0
        for dataloader_name, output in outputs.items():
            loss += output["loss"]
            self.log(f"val_loss_{dataloader_name}", output["loss"])
            for k, v in output["loss_dict"].items():
                self.log(f"val_loss_{k}_{dataloader_name}", v.item())

        return loss

    def predict_step(self, batch, batch_idx, dataloader_idx=None):
        """Predict step."""
        # extract pixel_values, pixel_mask and images from batch
        pixel_values = batch["pixel_values"]
        pixel_mask = batch["pixel_mask"]
        images = batch["images"]

        image_processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")

        # forward pass
        outputs = self(pixel_values=pixel_values, pixel_mask=pixel_mask)

        # postprocess outputs
        sizes = torch.tensor([image.size[::-1] for image in images], device=self.device)
        processed_outputs = image_processor.post_process_object_detection(
            outputs, threshold=self.prediction_threshold, target_sizes=sizes
        )

        for i, image in enumerate(images):
            # create ImageDraw object to draw on image
            draw = ImageDraw.Draw(image)

            # draw predicted bboxes
            for bbox, label, score in zip(
                processed_outputs[i]["boxes"].cpu().detach().numpy(),
                processed_outputs[i]["labels"].cpu().detach().numpy(),
                processed_outputs[i]["scores"].cpu().detach().numpy(),
            ):
                if label == 0:
                    outline = "red"
                elif label == 1:
                    outline = "blue"
                else:
                    outline = "green"
                draw.rectangle(bbox, outline=outline, width=5)
                draw.text((bbox[0], bbox[1]), f"{score:0.4f}", fill="black")

            # save the annotated image with PIL
            image.save(f"image2_{batch_idx}_{i}.jpg")

    def configure_optimizers(self):
        """Configure optimizers."""
        param_dicts = [
            {
                "params": [p for n, p in self.named_parameters() if "backbone" not in n and p.requires_grad],
            },
            {
                "params": [p for n, p in self.named_parameters() if "backbone" in n and p.requires_grad],
                "lr": self.lr_backbone,
            },
        ]
        optimizer = torch.optim.AdamW(param_dicts, lr=self.lr, weight_decay=self.weight_decay)

        scheduler = get_cosine_with_hard_restarts_schedule_with_warmup(
            optimizer,
            num_warmup_steps=self.warmup_steps,
            num_training_steps=self.trainer.estimated_stepping_batches,
        )

        return [optimizer], [{"scheduler": scheduler, "interval": "step"}]
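As a quick sanity check of the two learning rates configured above, the parameter split can be reproduced outside the Trainer. A sketch; instantiating DETR downloads the pretrained weights:

import torch
from module import DETR

model = DETR()
head = [p for n, p in model.named_parameters() if "backbone" not in n and p.requires_grad]
backbone = [p for n, p in model.named_parameters() if "backbone" in n and p.requires_grad]
optimizer = torch.optim.AdamW(
    [{"params": head}, {"params": backbone, "lr": model.lr_backbone}],
    lr=model.lr,
    weight_decay=model.weight_decay,
)
print(optimizer.param_groups[0]["lr"], optimizer.param_groups[1]["lr"])  # 0.0001 1e-05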
@@ -1 +0,0 @@
from .mrcnn import MRCNNModule
@@ -1,193 +0,0 @@
"""Mask R-CNN Pytorch Lightning Module for Object Detection and Segmentation."""

from typing import Any, Dict, List

import pytorch_lightning as pl
import torch
import torchvision
import wandb
from torchmetrics.detection.mean_ap import MeanAveragePrecision
from torchvision.models.detection.faster_rcnn import FastRCNNPredictor
from torchvision.models.detection.mask_rcnn import (
    MaskRCNN,
    MaskRCNN_ResNet50_FPN_Weights,
    MaskRCNNPredictor,
)

Prediction = List[Dict[str, torch.Tensor]]


def get_model_instance_segmentation(n_classes: int) -> MaskRCNN:
    """Returns a Torchvision MaskRCNN model for fine-tuning.

    Args:
        n_classes (int): number of classes the model should predict, background included

    Returns:
        MaskRCNN: the model ready to be used
    """
    # load an instance segmentation model pre-trained on COCO
    model = torchvision.models.detection.maskrcnn_resnet50_fpn(
        weights=MaskRCNN_ResNet50_FPN_Weights.DEFAULT,
        box_detections_per_img=10,  # cap the number of detections, else memory explodes
    )

    # get number of input features for the classifier
    in_features = model.roi_heads.box_predictor.cls_score.in_features
    # replace the pre-trained head with a new one
    model.roi_heads.box_predictor = FastRCNNPredictor(in_features, n_classes)

    # now get the number of input features for the mask classifier
    in_features_mask = model.roi_heads.mask_predictor.conv5_mask.in_channels
    hidden_layer = 256
    # and replace the mask predictor with a new one
    model.roi_heads.mask_predictor = MaskRCNNPredictor(in_features_mask, hidden_layer, n_classes)

    return model


class MRCNNModule(pl.LightningModule):
    """Mask R-CNN Pytorch Lightning Module, encapsulating common PyTorch functions."""

    def __init__(self, n_classes: int) -> None:
        """Constructor, build model, save hyperparameters.

        Args:
            n_classes (int): number of classes the model should predict, background included
        """
        super().__init__()

        # Hyperparameters
        self.n_classes = n_classes

        # log hyperparameters
        self.save_hyperparameters()

        # Network
        self.model = get_model_instance_segmentation(n_classes)

        # onnx export
        # self.example_input_array = torch.randn(1, 3, 1024, 1024, requires_grad=True).half()

        # torchmetrics
        self.metric_bbox = MeanAveragePrecision(iou_type="bbox")
        self.metric_segm = MeanAveragePrecision(iou_type="segm")

    # def forward(self, imgs: torch.Tensor) -> Prediction:  # type: ignore
    #     """Make a forward pass (prediction), useful for onnx export.

    #     Args:
    #         imgs (torch.Tensor): the images whose prediction we wish to make

    #     Returns:
    #         torch.Tensor: the predictions
    #     """
    #     self.model.eval()
    #     pred: Prediction = self.model(imgs)
    #     return pred

    def training_step(self, batch: torch.Tensor, batch_idx: int) -> float:  # type: ignore
        """PyTorch training step.

        Args:
            batch (torch.Tensor): the batch to train the model on
            batch_idx (int): the batch index number

        Returns:
            float: the training loss of this step
        """
        # unpack batch
        images, targets = batch

        # compute loss
        loss_dict: dict[str, float] = self.model(images, targets)
        loss_dict = {f"train/{key}": val for key, val in loss_dict.items()}
        loss = sum(loss_dict.values())
        loss_dict["train/loss"] = loss

        # log everything
        self.log_dict(loss_dict)

        return loss

    def on_validation_epoch_start(self) -> None:
        """Reset TorchMetrics."""
        self.metric_bbox.reset()
        self.metric_segm.reset()

    def validation_step(self, batch: torch.Tensor, batch_idx: int) -> Prediction:  # type: ignore
        """PyTorch validation step.

        Args:
            batch (torch.Tensor): the batch to evaluate the model on
            batch_idx (int): the batch index number

        Returns:
            torch.Tensor: the predictions
        """
        # unpack batch
        images, targets = batch

        # make prediction
        preds: Prediction = self.model(images)

        # update TorchMetrics from predictions
        for pred, target in zip(preds, targets):
            pred["masks"] = pred["masks"].squeeze(1).int().bool()
            target["masks"] = target["masks"].squeeze(1).int().bool()
        self.metric_bbox.update(preds, targets)
        self.metric_segm.update(preds, targets)

        return preds

    def validation_epoch_end(self, outputs: List[Prediction]) -> None:  # type: ignore
        """Compute TorchMetrics.

        Args:
            outputs (List[Prediction]): list of predictions from validation steps
        """
        # compute metrics
        metric_dict_bbox = self.metric_bbox.compute()
        metric_dict_segm = self.metric_segm.compute()
        metric_dict_sum = {
            f"valid/sum/{k}": metric_dict_bbox.get(k, 0) + metric_dict_segm.get(k, 0)
            for k in set(metric_dict_bbox) & set(metric_dict_segm)
        }

        # change metrics keys
        metric_dict_bbox = {f"valid/bbox/{key}": val for key, val in metric_dict_bbox.items()}
        metric_dict_segm = {f"valid/segm/{key}": val for key, val in metric_dict_segm.items()}

        # log metrics
        self.log_dict(metric_dict_bbox)
        self.log_dict(metric_dict_segm)
        self.log_dict(metric_dict_sum)

    def configure_optimizers(self) -> Dict[str, Any]:
        """PyTorch optimizers and Schedulers.

        Returns:
            Dict[str, Any]: dictionary for PyTorch Lightning optimizer/scheduler configuration
        """
        optimizer = torch.optim.Adam(
            self.parameters(),
            lr=wandb.config.LEARNING_RATE,
            # momentum=wandb.config.MOMENTUM,
            # weight_decay=wandb.config.WEIGHT_DECAY,
        )

        # scheduler = LinearWarmupCosineAnnealingLR(
        #     optimizer,
        #     warmup_epochs=1,
        #     max_epochs=30,
        # )

        return {
            "optimizer": optimizer,
            # "lr_scheduler": {
            #     "scheduler": scheduler,
            #     "interval": "step",
            #     "frequency": 10,
            #     "monitor": "bbox/map",
            # },
        }
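For reference, a minimal self-contained example of the TorchMetrics mAP computation this module relied on; the boxes and labels are made up for illustration:

import torch
from torchmetrics.detection.mean_ap import MeanAveragePrecision

metric = MeanAveragePrecision(iou_type="bbox")
preds = [{
    "boxes": torch.tensor([[10.0, 10.0, 50.0, 50.0]]),
    "scores": torch.tensor([0.9]),
    "labels": torch.tensor([1]),
}]
targets = [{
    "boxes": torch.tensor([[12.0, 12.0, 48.0, 48.0]]),
    "labels": torch.tensor([1]),
}]
metric.update(preds, targets)
print(metric.compute()["map"])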
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
File diff suppressed because one or more lines are too long
242 src/spheres.py Normal file
@@ -0,0 +1,242 @@
"""Dataset class AI or NOT HuggingFace competition."""
|
||||
|
||||
import json
|
||||
import pathlib
|
||||
|
||||
import cv2
|
||||
import datasets
|
||||
import numpy as np
|
||||
|
||||
prefix = "/data/local-files/?d=spheres/"
|
||||
dataset_path = pathlib.Path("./dataset3/spheres/")
|
||||
annotation_path = pathlib.Path("./annotations2.json")
|
||||
|
||||
_VERSION = "1.0.0"
|
||||
|
||||
_DESCRIPTION = ""
|
||||
|
||||
_HOMEPAGE = ""
|
||||
|
||||
_LICENSE = ""
|
||||
|
||||
_NAMES = [
|
||||
# "White",
|
||||
# "Black",
|
||||
# "Grey",
|
||||
# "Red",
|
||||
# "Chrome",
|
||||
"Matte",
|
||||
"Shiny",
|
||||
"Chrome",
|
||||
]
|
||||
|
||||
|
||||
class spheres(datasets.GeneratorBasedBuilder):
|
||||
"""spheres image dataset."""
|
||||
|
||||
def _info(self):
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
version=_VERSION,
|
||||
homepage=_HOMEPAGE,
|
||||
license=_LICENSE,
|
||||
features=datasets.Features(
|
||||
{
|
||||
"image_id": datasets.Value("int64"),
|
||||
"image": datasets.Image(),
|
||||
"width": datasets.Value("int32"),
|
||||
"height": datasets.Value("int32"),
|
||||
"objects": [
|
||||
{
|
||||
"category_id": datasets.ClassLabel(names=_NAMES),
|
||||
"image_id": datasets.Value("int64"),
|
||||
"id": datasets.Value("string"),
|
||||
"area": datasets.Value("float32"),
|
||||
"bbox": datasets.Sequence(datasets.Value("float32"), length=4),
|
||||
"iscrowd": datasets.Value("bool"),
|
||||
}
|
||||
],
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager):
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"dataset_path": dataset_path,
|
||||
"annotation_path": annotation_path,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
def _generate_examples(self, dataset_path: pathlib.Path, annotation_path: pathlib.Path):
|
||||
"""Generate images and labels for splits."""
|
||||
with open(annotation_path, "r") as f:
|
||||
tasks = json.load(f)
|
||||
index = 0
|
||||
|
||||
for task in tasks:
|
||||
image_id = task["id"]
|
||||
image_name = task["data"]["img"]
|
||||
image_name = image_name[len(prefix) :]
|
||||
image_name = pathlib.Path(image_name)
|
||||
|
||||
# skip shitty images
|
||||
# if "Soulages" in str(image_name):
|
||||
# continue
|
||||
|
||||
# check image_name exists
|
||||
assert (dataset_path / image_name).is_file()
|
||||
|
||||
# create annotation groups
|
||||
annotation_groups: dict[str, list[dict]] = {}
|
||||
for annotation in task["annotations"][0]["result"]:
|
||||
id = annotation["id"]
|
||||
if "parentID" in annotation:
|
||||
parent_id = annotation["parentID"]
|
||||
if parent_id not in annotation_groups:
|
||||
annotation_groups[parent_id] = []
|
||||
annotation_groups[parent_id].append(annotation)
|
||||
else:
|
||||
if id not in annotation_groups:
|
||||
annotation_groups[id] = []
|
||||
annotation_groups[id].append(annotation)
|
||||
|
||||
# check all annotations have same width and height
|
||||
width = task["annotations"][0]["result"][0]["original_width"]
|
||||
height = task["annotations"][0]["result"][0]["original_height"]
|
||||
for annotation in task["annotations"][0]["result"]:
|
||||
assert annotation["original_width"] == width
|
||||
assert annotation["original_height"] == height
|
||||
|
||||
# check all childs of group have same label
|
||||
labels = {}
|
||||
for group_id, annotations in annotation_groups.items():
|
||||
label = annotations[0]["value"]["keypointlabels"][0]
|
||||
for annotation in annotations:
|
||||
assert annotation["value"]["keypointlabels"][0] == label
|
||||
|
||||
if label == "White":
|
||||
label = "Matte"
|
||||
elif label == "Black":
|
||||
label = "Shiny"
|
||||
elif label == "Red":
|
||||
label = "Shiny"
|
||||
|
||||
labels[group_id] = label
|
||||
|
||||
# compute bboxes
|
||||
bboxes = {}
|
||||
for group_id, annotations in annotation_groups.items():
|
||||
# convert points to numpy array
|
||||
points = np.array(
|
||||
[
|
||||
[
|
||||
annotation["value"]["x"] / 100 * width,
|
||||
annotation["value"]["y"] / 100 * height,
|
||||
]
|
||||
for annotation in annotations
|
||||
],
|
||||
dtype=np.float32,
|
||||
)
|
||||
|
||||
# fit ellipse from points
|
||||
ellipse = cv2.fitEllipse(points)
|
||||
|
||||
# extract ellipse parameters
|
||||
x_C = ellipse[0][0]
|
||||
y_C = ellipse[0][1]
|
||||
a = ellipse[1][0] / 2
|
||||
b = ellipse[1][1] / 2
|
||||
theta = ellipse[2] * np.pi / 180
|
||||
|
||||
# sample ellipse points
|
||||
t = np.linspace(0, 2 * np.pi, 100)
|
||||
x = x_C + a * np.cos(t) * np.cos(theta) - b * np.sin(t) * np.sin(theta)
|
||||
y = y_C + a * np.cos(t) * np.sin(theta) + b * np.sin(t) * np.cos(theta)
|
||||
|
||||
# get bounding box
|
||||
xmin = np.min(x)
|
||||
xmax = np.max(x)
|
||||
ymin = np.min(y)
|
||||
ymax = np.max(y)
|
||||
|
||||
w = xmax - xmin
|
||||
h = ymax - ymin
|
||||
|
||||
# bboxe to coco format
|
||||
# https://github.com/huggingface/transformers/blob/main/src/transformers/models/detr/image_processing_detr.py#L295
|
||||
bboxes[group_id] = [xmin, ymin, w, h]
|
||||
|
||||
# compute areas
|
||||
areas = {group_id: w * h for group_id, (_, _, w, h) in bboxes.items()}
|
||||
|
||||
# generate data
|
||||
data = {
|
||||
"image_id": image_id,
|
||||
"image": str(dataset_path / image_name),
|
||||
"width": width,
|
||||
"height": height,
|
||||
"objects": [
|
||||
{
|
||||
# "category_id": "White",
|
||||
"category_id": labels[group_id],
|
||||
"image_id": image_id,
|
||||
"id": group_id,
|
||||
"area": areas[group_id],
|
||||
"bbox": bboxes[group_id],
|
||||
"iscrowd": False,
|
||||
}
|
||||
for group_id in annotation_groups
|
||||
],
|
||||
}
|
||||
|
||||
yield index, data
|
||||
index += 1
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from PIL import ImageDraw
|
||||
|
||||
# load dataset
|
||||
dataset = datasets.load_dataset("src/spheres.py", split="train")
|
||||
print("a")
|
||||
|
||||
labels = dataset.features["objects"][0]["category_id"].names
|
||||
id2label = {k: v for k, v in enumerate(labels)}
|
||||
label2id = {v: k for k, v in enumerate(labels)}
|
||||
|
||||
print(f"labels: {labels}")
|
||||
print(f"id2label: {id2label}")
|
||||
print(f"label2id: {label2id}")
|
||||
print()
|
||||
|
||||
idx = 0
|
||||
|
||||
while True:
|
||||
image = dataset[idx]["image"]
|
||||
if "DSC_4234" in image.filename:
|
||||
break
|
||||
idx += 1
|
||||
|
||||
if idx > 10000:
|
||||
break
|
||||
|
||||
print(f"image path: {image.filename}")
|
||||
print(f"data: {dataset[idx]}")
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
for obj in dataset[idx]["objects"]:
|
||||
bbox = (
|
||||
obj["bbox"][0],
|
||||
obj["bbox"][1],
|
||||
obj["bbox"][0] + obj["bbox"][2],
|
||||
obj["bbox"][1] + obj["bbox"][3],
|
||||
)
|
||||
draw.rectangle(bbox, outline="red", width=3)
|
||||
draw.text(bbox[:2], text=id2label[obj["category_id"]], fill="black")
|
||||
|
||||
# save image
|
||||
image.save("example.jpg")
|
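The generator above bounds each ellipse by sampling 100 points; for reference, the axis-aligned box of a rotated ellipse also has a closed form. A sketch with the same parameter names; this helper is illustrative, not part of the commit:

import numpy as np

def ellipse_bbox(x_C, y_C, a, b, theta):
    """COCO-style (xmin, ymin, w, h) box of an ellipse rotated by theta."""
    half_w = np.sqrt((a * np.cos(theta)) ** 2 + (b * np.sin(theta)) ** 2)
    half_h = np.sqrt((a * np.sin(theta)) ** 2 + (b * np.cos(theta)) ** 2)
    return x_C - half_w, y_C - half_h, 2 * half_w, 2 * half_h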
175 src/spheres_illumination.py Normal file
@@ -0,0 +1,175 @@
"""Dataset class AI or NOT HuggingFace competition."""
|
||||
|
||||
import pathlib
|
||||
|
||||
import json
|
||||
import datasets
|
||||
|
||||
dataset_path_train = pathlib.Path("/home/laurent/proj-long/dataset_illumination/")
|
||||
dataset_path_test = pathlib.Path("/home/laurent/proj-long/dataset_illumination_test/")
|
||||
|
||||
_VERSION = "1.0.0"
|
||||
|
||||
_DESCRIPTION = ""
|
||||
|
||||
_HOMEPAGE = ""
|
||||
|
||||
_LICENSE = ""
|
||||
|
||||
_NAMES = [
|
||||
"Matte",
|
||||
"Shiny",
|
||||
"Chrome",
|
||||
]
|
||||
|
||||
|
||||
class spheresSynth(datasets.GeneratorBasedBuilder):
|
||||
"""spheres image dataset."""
|
||||
|
||||
def _info(self):
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
version=_VERSION,
|
||||
homepage=_HOMEPAGE,
|
||||
license=_LICENSE,
|
||||
features=datasets.Features(
|
||||
{
|
||||
"image_id": datasets.Value("int64"),
|
||||
"image": datasets.Image(),
|
||||
"width": datasets.Value("int32"),
|
||||
"height": datasets.Value("int32"),
|
||||
"objects": [
|
||||
{
|
||||
"category_id": datasets.ClassLabel(names=_NAMES),
|
||||
"image_id": datasets.Value("int64"),
|
||||
"id": datasets.Value("string"),
|
||||
"area": datasets.Value("float32"),
|
||||
"bbox": datasets.Sequence(datasets.Value("float32"), length=4),
|
||||
"iscrowd": datasets.Value("bool"),
|
||||
}
|
||||
],
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager):
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"dataset_path": dataset_path_train,
|
||||
},
|
||||
),
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TEST,
|
||||
gen_kwargs={
|
||||
"dataset_path": dataset_path_test,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
def _generate_examples(self, dataset_path: pathlib.Path):
|
||||
"""Generate images and labels for splits."""
|
||||
width = 1500
|
||||
height = 1000
|
||||
|
||||
original_width = 6020
|
||||
original_height = 4024
|
||||
|
||||
# create png iterator
|
||||
object_index = 0
|
||||
jpgs = dataset_path.rglob("*.jpg")
|
||||
for index, jpg in enumerate(jpgs):
|
||||
|
||||
# filter out probe images
|
||||
if "probes" in jpg.parts:
|
||||
continue
|
||||
|
||||
# filter out thumbnails
|
||||
if "thumb" in jpg.stem:
|
||||
continue
|
||||
|
||||
# open corresponding csv file
|
||||
json_file = jpg.parent / "meta.json"
|
||||
|
||||
# read json
|
||||
with open(json_file, "r") as f:
|
||||
meta = json.load(f)
|
||||
|
||||
gray = (
|
||||
(
|
||||
meta["gray"]["bounding_box"]["x"] / original_width * width,
|
||||
meta["gray"]["bounding_box"]["y"] / original_height * height,
|
||||
meta["gray"]["bounding_box"]["w"] / original_width * width,
|
||||
meta["gray"]["bounding_box"]["h"] / original_height * height
|
||||
),
|
||||
"Matte"
|
||||
)
|
||||
|
||||
chrome = (
|
||||
(
|
||||
meta["chrome"]["bounding_box"]["x"] / original_width * width,
|
||||
meta["chrome"]["bounding_box"]["y"] / original_height * height,
|
||||
meta["chrome"]["bounding_box"]["w"] / original_width * width,
|
||||
meta["chrome"]["bounding_box"]["h"] / original_height * height
|
||||
),
|
||||
"Chrome"
|
||||
)
|
||||
|
||||
# generate data
|
||||
data = {
|
||||
"image_id": index,
|
||||
"image": str(jpg),
|
||||
"width": width,
|
||||
"height": height,
|
||||
"objects": [
|
||||
{
|
||||
"category_id": category,
|
||||
"image_id": index,
|
||||
"id": (object_index := object_index + 1),
|
||||
"area": bbox[2] * bbox[3],
|
||||
"bbox": bbox,
|
||||
"iscrowd": False,
|
||||
}
|
||||
for bbox, category in [gray, chrome]
|
||||
],
|
||||
}
|
||||
|
||||
yield index, data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from PIL import ImageDraw
|
||||
|
||||
# load dataset
|
||||
dataset = datasets.load_dataset("src/spheres_illumination.py", split="train")
|
||||
dataset = dataset.shuffle()
|
||||
|
||||
labels = dataset.features["objects"][0]["category_id"].names
|
||||
id2label = {k: v for k, v in enumerate(labels)}
|
||||
label2id = {v: k for k, v in enumerate(labels)}
|
||||
|
||||
print(f"labels: {labels}")
|
||||
print(f"id2label: {id2label}")
|
||||
print(f"label2id: {label2id}")
|
||||
print()
|
||||
|
||||
for idx in range(10):
|
||||
image = dataset[idx]["image"]
|
||||
|
||||
print(f"image path: {image.filename}")
|
||||
print(f"data: {dataset[idx]}")
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
for obj in dataset[idx]["objects"]:
|
||||
bbox = (
|
||||
obj["bbox"][0],
|
||||
obj["bbox"][1],
|
||||
obj["bbox"][0] + obj["bbox"][2],
|
||||
obj["bbox"][1] + obj["bbox"][3],
|
||||
)
|
||||
draw.rectangle(bbox, outline="red", width=3)
|
||||
draw.text(bbox[:2], text=id2label[obj["category_id"]], fill="black")
|
||||
|
||||
# save image
|
||||
image.save(f"example_{idx}.jpg")
|
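Both splits declared above can then be loaded independently; a sketch, noting that the dataset paths are machine-local so this only runs where they exist:

import datasets

train = datasets.load_dataset("src/spheres_illumination.py", split="train")
test = datasets.load_dataset("src/spheres_illumination.py", split="test")
print(len(train), len(test))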
113 src/spheres_predict.py Normal file
@@ -0,0 +1,113 @@
"""Dataset class AI or NOT HuggingFace competition."""
|
||||
|
||||
import pathlib
|
||||
|
||||
import datasets
|
||||
|
||||
dataset_path = pathlib.Path("/home/laurent/proj-long/dataset_predict/")
|
||||
|
||||
_VERSION = "1.0.0"
|
||||
|
||||
_DESCRIPTION = ""
|
||||
|
||||
_HOMEPAGE = ""
|
||||
|
||||
_LICENSE = ""
|
||||
|
||||
_NAMES = [
|
||||
"Matte",
|
||||
"Shiny",
|
||||
"Chrome",
|
||||
]
|
||||
|
||||
|
||||
class spheresSynth(datasets.GeneratorBasedBuilder):
|
||||
"""spheres image dataset."""
|
||||
|
||||
def _info(self):
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
version=_VERSION,
|
||||
homepage=_HOMEPAGE,
|
||||
license=_LICENSE,
|
||||
features=datasets.Features(
|
||||
{
|
||||
"image_id": datasets.Value("int64"),
|
||||
"image": datasets.Image(),
|
||||
"objects": [
|
||||
{
|
||||
"category_id": datasets.ClassLabel(names=_NAMES),
|
||||
"image_id": datasets.Value("int64"),
|
||||
"id": datasets.Value("string"),
|
||||
"area": datasets.Value("float32"),
|
||||
"bbox": datasets.Sequence(datasets.Value("float32"), length=4),
|
||||
"iscrowd": datasets.Value("bool"),
|
||||
}
|
||||
],
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager):
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"dataset_path": dataset_path,
|
||||
},
|
||||
)
|
||||
]
|
||||
|
||||
def _generate_examples(self, dataset_path: pathlib.Path):
|
||||
"""Generate images and labels for splits."""
|
||||
# create png iterator
|
||||
jpgs = dataset_path.rglob("*.jpg")
|
||||
|
||||
for index, jpg in enumerate(jpgs):
|
||||
|
||||
print(index, jpg, 2)
|
||||
|
||||
# generate data
|
||||
data = {
|
||||
"image_id": index,
|
||||
"image": str(jpg),
|
||||
"objects": [],
|
||||
}
|
||||
|
||||
yield index, data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from PIL import ImageDraw
|
||||
|
||||
# load dataset
|
||||
dataset = datasets.load_dataset("src/spheres_predict.py", split="train")
|
||||
|
||||
labels = dataset.features["objects"][0]["category_id"].names
|
||||
id2label = {k: v for k, v in enumerate(labels)}
|
||||
label2id = {v: k for k, v in enumerate(labels)}
|
||||
|
||||
print(f"labels: {labels}")
|
||||
print(f"id2label: {id2label}")
|
||||
print(f"label2id: {label2id}")
|
||||
print()
|
||||
|
||||
for idx in range(10):
|
||||
image = dataset[idx]["image"]
|
||||
|
||||
print(f"image path: {image.filename}")
|
||||
print(f"data: {dataset[idx]}")
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
for obj in dataset[idx]["objects"]:
|
||||
bbox = (
|
||||
obj["bbox"][0],
|
||||
obj["bbox"][1],
|
||||
obj["bbox"][0] + obj["bbox"][2],
|
||||
obj["bbox"][1] + obj["bbox"][3],
|
||||
)
|
||||
draw.rectangle(bbox, outline="red", width=3)
|
||||
draw.text(bbox[:2], text=id2label[obj["category_id"]], fill="black")
|
||||
|
||||
# save image
|
||||
image.save(f"example_{idx}.jpg")
|
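A hypothetical end-to-end check on one example from this split, using the standard transformers processing API with the same pretrained checkpoint; the classification head is untouched here, so the logits cover the stock COCO label set:

import datasets
from transformers import AutoImageProcessor, DetrForObjectDetection

ds = datasets.load_dataset("src/spheres_predict.py", split="train")
processor = AutoImageProcessor.from_pretrained("facebook/detr-resnet-50")
model = DetrForObjectDetection.from_pretrained("facebook/detr-resnet-50")

inputs = processor(images=ds[0]["image"], return_tensors="pt")
outputs = model(**inputs)
print(outputs.logits.shape)  # (1, num_queries, num_classes + 1)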
174 src/spheres_synth.py Normal file
@@ -0,0 +1,174 @@
"""Dataset class AI or NOT HuggingFace competition."""
|
||||
|
||||
import pathlib
|
||||
|
||||
import cv2
|
||||
import datasets
|
||||
import numpy as np
|
||||
|
||||
dataset_path = pathlib.Path("/home/laurent/proj-long/dataset_render/")
|
||||
|
||||
_VERSION = "1.0.0"
|
||||
|
||||
_DESCRIPTION = ""
|
||||
|
||||
_HOMEPAGE = ""
|
||||
|
||||
_LICENSE = ""
|
||||
|
||||
_NAMES = [
|
||||
"Matte",
|
||||
"Shiny",
|
||||
"Chrome",
|
||||
]
|
||||
|
||||
|
||||
class spheresSynth(datasets.GeneratorBasedBuilder):
|
||||
"""spheres image dataset."""
|
||||
|
||||
def _info(self):
|
||||
return datasets.DatasetInfo(
|
||||
description=_DESCRIPTION,
|
||||
version=_VERSION,
|
||||
homepage=_HOMEPAGE,
|
||||
license=_LICENSE,
|
||||
features=datasets.Features(
|
||||
{
|
||||
"image_id": datasets.Value("int64"),
|
||||
"image": datasets.Image(),
|
||||
"width": datasets.Value("int32"),
|
||||
"height": datasets.Value("int32"),
|
||||
"objects": [
|
||||
{
|
||||
"category_id": datasets.ClassLabel(names=_NAMES),
|
||||
"image_id": datasets.Value("int64"),
|
||||
"id": datasets.Value("string"),
|
||||
"area": datasets.Value("float32"),
|
||||
"bbox": datasets.Sequence(datasets.Value("float32"), length=4),
|
||||
"iscrowd": datasets.Value("bool"),
|
||||
}
|
||||
],
|
||||
}
|
||||
),
|
||||
)
|
||||
|
||||
def _split_generators(self, dl_manager):
|
||||
return [
|
||||
datasets.SplitGenerator(
|
||||
name=datasets.Split.TRAIN,
|
||||
gen_kwargs={
|
||||
"dataset_path": dataset_path,
|
||||
},
|
||||
),
|
||||
]
|
||||
|
||||
def _generate_examples(self, dataset_path: pathlib.Path):
|
||||
"""Generate images and labels for splits."""
|
||||
# create png iterator
|
||||
width = 1200
|
||||
height = 675
|
||||
object_index = 0
|
||||
pngs = dataset_path.glob("*.png")
|
||||
for index, png in enumerate(pngs):
|
||||
# open corresponding csv file
|
||||
csv = dataset_path / (png.stem + ".csv")
|
||||
|
||||
# read csv lines
|
||||
with open(csv, "r") as f:
|
||||
lines = f.readlines()
|
||||
lines = [line.strip().split(",") for line in lines]
|
||||
lines = [
|
||||
(
|
||||
float(line[0]),
|
||||
1 - float(line[1]),
|
||||
float(line[2]),
|
||||
1 - float(line[3]),
|
||||
line[4].strip()
|
||||
) for line in lines
|
||||
]
|
||||
|
||||
bboxes = [
|
||||
(
|
||||
line[0] * width,
|
||||
line[3] * height,
|
||||
(line[2] - line[0]) * width,
|
||||
(line[1] - line[3]) * height,
|
||||
)
|
||||
for line in lines
|
||||
]
|
||||
|
||||
categories = []
|
||||
for line in lines:
|
||||
category = line[4]
|
||||
|
||||
if category == "White":
|
||||
category = "Matte"
|
||||
elif category == "Black":
|
||||
category = "Shiny"
|
||||
elif category == "Grey":
|
||||
category = "Matte"
|
||||
elif category == "Red":
|
||||
category = "Shiny"
|
||||
elif category == "Chrome":
|
||||
category = "Chrome"
|
||||
elif category == "Cyan":
|
||||
category = "Shiny"
|
||||
|
||||
categories.append(category)
|
||||
|
||||
# generate data
|
||||
data = {
|
||||
"image_id": index,
|
||||
"image": str(png),
|
||||
"width": width,
|
||||
"height": height,
|
||||
"objects": [
|
||||
{
|
||||
"category_id": category,
|
||||
"image_id": index,
|
||||
"id": (object_index := object_index + 1),
|
||||
"area": bbox[2] * bbox[3],
|
||||
"bbox": bbox,
|
||||
"iscrowd": False,
|
||||
}
|
||||
for bbox, category in zip(bboxes, categories)
|
||||
],
|
||||
}
|
||||
|
||||
yield index, data
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
from PIL import ImageDraw
|
||||
|
||||
# load dataset
|
||||
dataset = datasets.load_dataset("src/spheres_synth.py", split="train")
|
||||
|
||||
labels = dataset.features["objects"][0]["category_id"].names
|
||||
id2label = {k: v for k, v in enumerate(labels)}
|
||||
label2id = {v: k for k, v in enumerate(labels)}
|
||||
|
||||
print(f"labels: {labels}")
|
||||
print(f"id2label: {id2label}")
|
||||
print(f"label2id: {label2id}")
|
||||
print()
|
||||
|
||||
for idx in range(10):
|
||||
image = dataset[idx]["image"]
|
||||
|
||||
# print(f"image path: {image.filename}")
|
||||
# print(f"data: {dataset[idx]}")
|
||||
|
||||
draw = ImageDraw.Draw(image)
|
||||
for obj in dataset[idx]["objects"]:
|
||||
bbox = (
|
||||
obj["bbox"][0],
|
||||
obj["bbox"][1],
|
||||
obj["bbox"][0] + obj["bbox"][2],
|
||||
obj["bbox"][1] + obj["bbox"][3],
|
||||
)
|
||||
draw.rectangle(bbox, outline="red", width=3)
|
||||
draw.text(bbox[:2], text=id2label[obj["category_id"]], fill="black")
|
||||
|
||||
# save image
|
||||
image.save(f"example_{idx}.jpg")
|
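A worked example (made-up numbers) of the normalized-corner to COCO conversion above, after the 1 - y flip has been applied; in image coordinates the top edge has the smaller y:

width, height = 1200, 675
x0, y_bottom, x1, y_top = 0.10, 0.40, 0.30, 0.20  # already flipped, image coords
bbox = (x0 * width, y_top * height, (x1 - x0) * width, (y_bottom - y_top) * height)
print(bbox)  # (120.0, 135.0, 240.0, 135.0)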
52 src/tmp.py
@@ -1,52 +0,0 @@
from pathlib import Path
from threading import Thread

import albumentations as A
import numpy as np
import torchvision.transforms as T

from data.dataset import SyntheticDataset
from utils import RandomPaste

transform = A.Compose(
    [
        A.LongestMaxSize(max_size=1024),
        A.Flip(),
        RandomPaste(5, "/media/disk1/lfainsin/SPHERES/WHITE", "/dev/null"),
        A.ToGray(p=0.01),
        A.ISONoise(),
        A.ImageCompression(),
    ],
)

dataset = SyntheticDataset(image_dir="/media/disk1/lfainsin/BACKGROUND/coco/", transform=transform)
transform = T.ToPILImage()


def render(i, image, mask):
    image = transform(image)
    mask = transform(mask)

    path = f"/media/disk1/lfainsin/TRAIN_prerender/{i:06d}/"
    Path(path).mkdir(parents=True, exist_ok=True)

    image.save(f"{path}/image.jpg")
    mask.save(f"{path}/MASK.PNG")


def renderlist(list_i, dataset):
    for i in list_i:
        image, mask = dataset[i]
        render(i, image, mask)


sublists = np.array_split(range(len(dataset)), 16 * 5)
threads = []
for sublist in sublists:
    t = Thread(target=renderlist, args=(sublist, dataset))
    t.start()
    threads.append(t)

# join all threads
for t in threads:
    t.join()
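An equivalent sketch using concurrent.futures instead of manual Thread bookkeeping; threads help here mainly because PIL and NumPy release the GIL during image work and disk I/O:

from concurrent.futures import ThreadPoolExecutor

# render and dataset as defined above; the pool joins on exiting the block,
# and results are ignored, as in the original script
with ThreadPoolExecutor(max_workers=16) as pool:
    pool.map(lambda i: render(i, *dataset[i]), range(len(dataset)))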
83 src/train.py
@@ -1,83 +0,0 @@
"""Main script, to be launched to start the fine tuning of the neural network."""
|
||||
|
||||
import pytorch_lightning as pl
|
||||
import wandb
|
||||
from pytorch_lightning.callbacks import (
|
||||
EarlyStopping,
|
||||
LearningRateMonitor,
|
||||
ModelCheckpoint,
|
||||
RichModelSummary,
|
||||
RichProgressBar,
|
||||
)
|
||||
from pytorch_lightning.loggers import WandbLogger
|
||||
|
||||
from data import Spheres
|
||||
from modules import MRCNNModule
|
||||
from utils.callback import TableLog
|
||||
|
||||
if __name__ == "__main__":
|
||||
# setup wandb
|
||||
logger = WandbLogger(
|
||||
project="Mask R-CNN",
|
||||
config="wandb.yaml",
|
||||
save_dir="/tmp/",
|
||||
log_model="all",
|
||||
settings=wandb.Settings(
|
||||
code_dir="./src/",
|
||||
),
|
||||
)
|
||||
|
||||
# seed random generators
|
||||
pl.seed_everything(
|
||||
seed=wandb.config.SEED,
|
||||
workers=True,
|
||||
)
|
||||
|
||||
# Create Network
|
||||
module = MRCNNModule(
|
||||
n_classes=2,
|
||||
)
|
||||
|
||||
# load checkpoint
|
||||
# module.load_from_checkpoint("/tmp/model.ckpt")
|
||||
|
||||
# log gradients and weights regularly
|
||||
logger.watch(
|
||||
model=module.model,
|
||||
log="all",
|
||||
)
|
||||
|
||||
# Create the dataloaders
|
||||
datamodule = Spheres()
|
||||
|
||||
# Create the trainer
|
||||
trainer = pl.Trainer(
|
||||
max_epochs=wandb.config.EPOCHS,
|
||||
accelerator=wandb.config.DEVICE,
|
||||
benchmark=wandb.config.BENCHMARK,
|
||||
deterministic=wandb.config.DETERMINISTIC,
|
||||
precision=wandb.config.PRECISION,
|
||||
logger=logger,
|
||||
log_every_n_steps=5,
|
||||
val_check_interval=250,
|
||||
callbacks=[
|
||||
EarlyStopping(monitor="valid/sum/map", mode="max", patience=10, min_delta=0.01),
|
||||
ModelCheckpoint(monitor="valid/sum/map", mode="max"),
|
||||
# ModelPruning("l1_unstructured", amount=0.5),
|
||||
LearningRateMonitor(log_momentum=True),
|
||||
# StochasticWeightAveraging(swa_lrs=1e-2),
|
||||
RichModelSummary(max_depth=2),
|
||||
RichProgressBar(),
|
||||
TableLog(),
|
||||
],
|
||||
# profiler="advanced",
|
||||
gradient_clip_val=1,
|
||||
num_sanity_val_steps=3,
|
||||
devices=[0],
|
||||
)
|
||||
|
||||
# actually train the model
|
||||
trainer.fit(model=module, datamodule=datamodule)
|
||||
|
||||
# stop wandb
|
||||
wandb.run.finish() # type: ignore
|
|
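One note on the commented checkpoint line above: load_from_checkpoint is a classmethod that returns a new module, so calling it on an existing instance and discarding the result would have no effect. A corrected sketch:

# hypothetical checkpoint path, matching the commented-out line in the script
module = MRCNNModule.load_from_checkpoint("/tmp/model.ckpt")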
@@ -1,2 +0,0 @@
from .callback import TableLog
from .paste import RandomPaste
@@ -1,136 +0,0 @@
import wandb
from pytorch_lightning.callbacks import Callback

columns = [
    "image",
]
class_labels = {
    1: "sphere",
    2: "chrome",
    10: "sphere_gt",
    20: "chrome_gt",
}


class TableLog(Callback):
    def on_train_batch_end(self, trainer, pl_module, outputs, batch, batch_idx):
        if batch_idx == 0:
            rows = []

            # unpacking
            images, targets = batch

            for image, target in zip(
                images,
                targets,
            ):
                rows.append(
                    [
                        wandb.Image(
                            image.cpu(),
                            masks={
                                "ground_truth": {
                                    "mask_data": (target["masks"] * target["labels"][:, None, None])
                                    .max(dim=0)
                                    .values.mul(10)
                                    .cpu()
                                    .numpy(),
                                    "class_labels": class_labels,
                                },
                            },
                        ),
                    ]
                )

            wandb.log(
                {
                    "train/predictions": wandb.Table(
                        columns=columns,
                        data=rows,
                    )
                }
            )

    def on_validation_epoch_start(self, trainer, pl_module):
        self.rows = []

    def on_validation_batch_end(self, trainer, pl_module, outputs, batch, batch_idx, dataloader_idx):
        if batch_idx == 2:
            # unpacking
            images, targets = batch

            for image, target, pred in zip(
                images,
                targets,
                outputs,
            ):
                box_data_gt = [
                    {
                        "position": {
                            "minX": int(target["boxes"][j][0]),
                            "minY": int(target["boxes"][j][1]),
                            "maxX": int(target["boxes"][j][2]),
                            "maxY": int(target["boxes"][j][3]),
                        },
                        "domain": "pixel",
                        "class_id": int(target["labels"][j] * 10),
                        "class_labels": class_labels,
                    }
                    for j in range(len(target["labels"]))
                ]

                box_data = [
                    {
                        "position": {
                            "minX": int(pred["boxes"][j][0]),
                            "minY": int(pred["boxes"][j][1]),
                            "maxX": int(pred["boxes"][j][2]),
                            "maxY": int(pred["boxes"][j][3]),
                        },
                        "domain": "pixel",
                        "class_id": int(pred["labels"][j]),
                        "box_caption": f"{pred['scores'][j]:0.3f}",
                        "class_labels": class_labels,
                    }
                    for j in range(len(pred["labels"]))
                ]

                self.rows.append(
                    [
                        wandb.Image(
                            image.cpu(),
                            masks={
                                "ground_truth": {
                                    "mask_data": (target["masks"] * target["labels"][:, None, None])
                                    .max(dim=0)
                                    .values.mul(10)
                                    .cpu()
                                    .numpy(),
                                    "class_labels": class_labels,
                                },
                                "predictions": {
                                    "mask_data": (pred["masks"] * pred["labels"][:, None, None])
                                    .max(dim=0)
                                    .values.cpu()
                                    .numpy(),
                                    "class_labels": class_labels,
                                },
                            },
                            boxes={
                                "ground_truth": {"box_data": box_data_gt},
                                "predictions": {"box_data": box_data},
                            },
                        ),
                    ]
                )

    def on_validation_epoch_end(self, trainer, pl_module):
        # log table
        wandb.log(
            {
                "valid/predictions": wandb.Table(
                    columns=columns,
                    data=self.rows,
                )
            }
        )
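For context, a small sketch of why ground-truth class ids are shifted by a factor of 10 in the callback above: it lets predicted and ground-truth overlays share one class_labels legend without colliding.

class_labels = {1: "sphere", 2: "chrome", 10: "sphere_gt", 20: "chrome_gt"}
pred_id = 1            # predicted "sphere"
gt_id = pred_id * 10   # ground-truth "sphere_gt", as done via .mul(10) above
assert class_labels[gt_id] == "sphere_gt"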
@@ -1,235 +0,0 @@
from __future__ import annotations

import random as rd
from dataclasses import dataclass
from pathlib import Path
from typing import List, Tuple

import albumentations as A
import numpy as np
import torchvision.transforms as T
from PIL import Image


class RandomPaste(A.DualTransform):
    """Paste an object on a background.

    Args:
        TODO

    Targets:
        image, mask

    Image types:
        uint8
    """

    def __init__(
        self,
        nb,
        sphere_image_dir,
        chrome_sphere_image_dir,
        scale_range=(0.05, 0.3),
        always_apply=True,
        p=1.0,
    ):
        super().__init__(always_apply, p)

        self.sphere_images = []
        self.sphere_images.extend(list(Path(sphere_image_dir).glob("**/*.jpg")))
        self.sphere_images.extend(list(Path(sphere_image_dir).glob("**/*.png")))

        self.chrome_sphere_images = []
        self.chrome_sphere_images.extend(list(Path(chrome_sphere_image_dir).glob("**/*.jpg")))
        self.chrome_sphere_images.extend(list(Path(chrome_sphere_image_dir).glob("**/*.png")))

        self.scale_range = scale_range
        self.nb = nb

    @property
    def targets_as_params(self):
        return ["image"]

    def apply(self, img, augmentation_datas, **params):
        # convert img to Image, needed for `paste` function
        img = Image.fromarray(img)

        # paste spheres
        for augmentation in augmentation_datas:
            paste_img_aug = T.functional.adjust_contrast(
                augmentation.paste_img,
                contrast_factor=augmentation.contrast,
            )
            paste_img_aug = T.functional.adjust_brightness(
                paste_img_aug,
                brightness_factor=augmentation.brightness,
            )
            paste_img_aug = T.functional.affine(
                paste_img_aug,
                scale=0.95,
                translate=(0, 0),
                angle=augmentation.angle,
                shear=augmentation.shear,
                interpolation=T.InterpolationMode.BICUBIC,
            )
            paste_img_aug = T.functional.resize(
                paste_img_aug,
                size=augmentation.shape,
                interpolation=T.InterpolationMode.LANCZOS,
            )

            paste_mask_aug = T.functional.affine(
                augmentation.paste_mask,
                scale=0.95,
                translate=(0, 0),
                angle=augmentation.angle,
                shear=augmentation.shear,
                interpolation=T.InterpolationMode.BICUBIC,
            )
            paste_mask_aug = T.functional.resize(
                paste_mask_aug,
                size=augmentation.shape,
                interpolation=T.InterpolationMode.LANCZOS,
            )

            img.paste(paste_img_aug, augmentation.position, paste_mask_aug)

        return np.array(img.convert("RGB"))

    def apply_to_mask(self, mask, augmentation_datas, **params):
        # convert mask to Image, needed for `paste` function
        mask = Image.fromarray(mask)

        for augmentation in augmentation_datas:
            paste_mask_aug = T.functional.affine(
                augmentation.paste_mask,
                scale=0.95,
                translate=(0, 0),
                angle=augmentation.angle,
                shear=augmentation.shear,
                interpolation=T.InterpolationMode.BICUBIC,
            )
            paste_mask_aug = T.functional.resize(
                paste_mask_aug,
                size=augmentation.shape,
                interpolation=T.InterpolationMode.LANCZOS,
            )

            # binarize the mask -> {0, 1}
            paste_mask_aug_bin = paste_mask_aug.point(lambda p: augmentation.value if p > 10 else 0)

            mask.paste(paste_mask_aug, augmentation.position, paste_mask_aug_bin)

        return np.array(mask.convert("L"))

    def get_params_dependent_on_targets(self, params):
        # init augmentation list
        augmentation_datas: List[AugmentationData] = []

        # load target image (w/ transparency)
        target_img = params["image"]
        target_shape = np.array(target_img.shape[:2], dtype=np.uint)

        # generate augmentations
        ite = 0
        NB = rd.randint(1, self.nb)
        while len(augmentation_datas) < NB:
            if ite > 100:
                break
            else:
                ite += 1

            # choose a random sphere image and its corresponding mask
            if rd.random() > 0.5 or len(self.chrome_sphere_images) == 0:
                img_path = rd.choice(self.sphere_images)
                value = len(augmentation_datas) + 1
            else:
                img_path = rd.choice(self.chrome_sphere_images)
                value = 255 - len(augmentation_datas)
            mask_path = img_path.parent.joinpath("MASK.PNG")

            # load paste assets
            paste_img = Image.open(img_path).convert("RGBA")
            paste_shape = np.array(paste_img.size, dtype=np.uint)
            paste_mask = Image.open(mask_path).convert("LA")

            # compute minimum scaling to fit inside target
            min_scale = np.min(target_shape / paste_shape)

            # randomly scale image inside target
            scale = rd.uniform(*self.scale_range) * min_scale
            shape = np.array(paste_shape * scale, dtype=np.uint)

            try:
                augmentation_datas.append(
                    AugmentationData(
                        position=(
                            rd.randint(0, target_shape[1] - shape[1]),
                            rd.randint(0, target_shape[0] - shape[0]),
                        ),
                        shear=(
                            rd.uniform(-2, 2),
                            rd.uniform(-2, 2),
                        ),
                        shape=tuple(shape),
                        angle=rd.uniform(0, 360),
                        brightness=rd.uniform(0.8, 1.2),
                        contrast=rd.uniform(0.8, 1.2),
                        paste_img=paste_img,
                        paste_mask=paste_mask,
                        value=value,
                        target_shape=tuple(target_shape),
                        other_augmentations=augmentation_datas,
                    )
                )
            except ValueError:
                continue

        params.update(
            {
                "augmentation_datas": augmentation_datas,
            }
        )

        return params


@dataclass
class AugmentationData:
    """Store data for pasting augmentation."""

    position: Tuple[int, int]

    shape: Tuple[int, int]
    target_shape: Tuple[int, int]
    angle: float

    brightness: float
    contrast: float

    shear: Tuple[float, float]

    paste_img: Image.Image
    paste_mask: Image.Image
    value: int

    other_augmentations: List[AugmentationData]

    def __post_init__(self) -> None:
        # check for overlapping
        if overlap(self.other_augmentations, self):
            raise ValueError


def overlap(augmentations: List[AugmentationData], augmentation: AugmentationData) -> bool:
    x1, y1 = augmentation.position
    w1, h1 = augmentation.shape

    for other_augmentation in augmentations:
        x2, y2 = other_augmentation.position
        w2, h2 = other_augmentation.shape

        if x1 + w1 >= x2 and x1 <= x2 + w2 and y1 + h1 >= y2 and y1 <= y2 + h2:
            return True

    return False
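The overlap helper above is a standard axis-aligned bounding-box intersection test; a self-contained rendition with made-up boxes, where tuples stand in for AugmentationData:

def intersects(box_a, box_b):
    # each box is ((x, y), (w, h)), mirroring position/shape above
    (x1, y1), (w1, h1) = box_a
    (x2, y2), (w2, h2) = box_b
    return x1 + w1 >= x2 and x1 <= x2 + w2 and y1 + h1 >= y2 and y1 <= y2 + h2

print(intersects(((0, 0), (10, 10)), ((5, 5), (10, 10))))   # True, boxes overlap
print(intersects(((0, 0), (10, 10)), ((20, 20), (4, 4))))   # False, disjoint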
44 wandb.yaml
@@ -1,44 +0,0 @@
DIR_TRAIN_IMG:
  value: "/media/disk1/lfainsin/TRAIN_prerender/"
DIR_VALID_IMG:
  value: "/media/disk1/lfainsin/TEST_tmp_mrcnn/"
# DIR_SPHERE:
#   value: "/media/disk1/lfainsin/SPHERES/"

N_CHANNELS:
  value: 3
N_CLASSES:
  value: 1

AMP:
  value: True
PIN_MEMORY:
  value: True
BENCHMARK:
  value: True
DETERMINISTIC:
  value: False
PRECISION:
  value: 32
SEED:
  value: 69420
DEVICE:
  value: gpu
WORKERS:
  value: 16

EPOCHS:
  value: 50
TRAIN_BATCH_SIZE:
  value: 6
VALID_BATCH_SIZE:
  value: 2
PREFETCH_FACTOR:
  value: 2

LEARNING_RATE:
  value: 0.001
WEIGHT_DECAY:
  value: 0.0001
MOMENTUM:
  value: 0.9
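The nested value: layout above follows the W&B config-file convention; once loaded, the values are flattened onto wandb.config. A hypothetical offline check, assuming wandb.yaml sits in the working directory:

import wandb

run = wandb.init(mode="offline", config="wandb.yaml")
print(wandb.config.LEARNING_RATE)  # 0.001
run.finish()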