feat(python): added some typing and comments/docstrings
This commit is contained in:
parent
604047f8eb
commit
5e77fa2833
|
@ -1,3 +1,5 @@
|
||||||
|
"""Emulator server, responsible for handling user inputs and outputting video & sound."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import logging
|
import logging
|
||||||
import os
|
import os
|
||||||
|
@ -30,8 +32,8 @@ from settings import (
|
||||||
RTMP_STREAM_URI,
|
RTMP_STREAM_URI,
|
||||||
)
|
)
|
||||||
|
|
||||||
core: mgba.core = mgba.core.load_path(EMULATOR_ROM_PATH)
|
core: mgba.core.Core = mgba.core.load_path(EMULATOR_ROM_PATH)
|
||||||
screen: mgba.image = mgba.image.Image(EMULATOR_WIDTH, EMULATOR_HEIGHT)
|
screen: mgba.image.Image = mgba.image.Image(EMULATOR_WIDTH, EMULATOR_HEIGHT)
|
||||||
core.set_video_buffer(screen)
|
core.set_video_buffer(screen)
|
||||||
core.reset()
|
core.reset()
|
||||||
|
|
||||||
|
@ -41,20 +43,7 @@ mgba.log.silence()
|
||||||
r: redis.Redis = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=0)
|
r: redis.Redis = redis.Redis(host=REDIS_HOST, port=REDIS_PORT, db=0)
|
||||||
|
|
||||||
|
|
||||||
def next_action():
|
# Launch ffmpeg process
|
||||||
"""Select the next key from the redis database.
|
|
||||||
|
|
||||||
Returns:
|
|
||||||
int: key used by mgba
|
|
||||||
"""
|
|
||||||
votes: list[int] = list(map(int, r.mget(KEYS_ID)))
|
|
||||||
if any(votes):
|
|
||||||
r.mset(KEYS_RESET)
|
|
||||||
return votes.index(max(votes))
|
|
||||||
else:
|
|
||||||
return -1
|
|
||||||
|
|
||||||
|
|
||||||
stream = Popen(
|
stream = Popen(
|
||||||
[
|
[
|
||||||
"/usr/bin/ffmpeg",
|
"/usr/bin/ffmpeg",
|
||||||
|
@ -91,13 +80,37 @@ stream = Popen(
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def next_action():
|
||||||
|
"""Select the next key from the redis database.
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
int: key used by mgba.
|
||||||
|
"""
|
||||||
|
votes: list[int] = list(map(int, r.mget(KEYS_ID)))
|
||||||
|
if any(votes):
|
||||||
|
r.mset(KEYS_RESET)
|
||||||
|
return votes.index(max(votes))
|
||||||
|
else:
|
||||||
|
return -1
|
||||||
|
|
||||||
|
|
||||||
def state_manager(loop: asyncio.AbstractEventLoop):
|
def state_manager(loop: asyncio.AbstractEventLoop):
|
||||||
|
"""Subscribe and respond to messages received from redis.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
loop (asyncio.AbstractEventLoop): the asyncio event loop.
|
||||||
|
"""
|
||||||
ps = r.pubsub()
|
ps = r.pubsub()
|
||||||
ps.subscribe("admin")
|
ps.subscribe("admin")
|
||||||
|
|
||||||
while True:
|
while True:
|
||||||
for message in ps.listen():
|
for message in ps.listen():
|
||||||
if message["type"] == "message":
|
if message["type"] == "message":
|
||||||
data = message["data"].decode("utf-8")
|
data = message["data"].decode("utf-8")
|
||||||
|
|
||||||
|
# TODO: voir si plus clean possible ?
|
||||||
|
# TODO: dev dans un docker ?
|
||||||
|
|
||||||
if data == "save":
|
if data == "save":
|
||||||
asyncio.ensure_future(utils.save(core), loop=loop)
|
asyncio.ensure_future(utils.save(core), loop=loop)
|
||||||
elif data.startswith("load:"):
|
elif data.startswith("load:"):
|
||||||
|
@ -105,27 +118,32 @@ def state_manager(loop: asyncio.AbstractEventLoop):
|
||||||
|
|
||||||
|
|
||||||
async def emulator():
|
async def emulator():
|
||||||
|
"""Start the main loop responsible for handling inputs and sending images to ffmpeg."""
|
||||||
while True:
|
while True:
|
||||||
last_frame_t = time.time()
|
last_frame_t = time.time()
|
||||||
|
|
||||||
|
# poll redis for keys
|
||||||
if not (core.frame_counter % EMULATOR_POLLING_RATE):
|
if not (core.frame_counter % EMULATOR_POLLING_RATE):
|
||||||
core.clear_keys(*KEYS_MGBA)
|
core.clear_keys(*KEYS_MGBA)
|
||||||
next_key = next_action()
|
next_key = next_action()
|
||||||
if next_key != -1:
|
if next_key != -1:
|
||||||
core.set_keys(next_key)
|
core.set_keys(next_key)
|
||||||
|
|
||||||
|
# mGBA run next frame
|
||||||
core.run_frame()
|
core.run_frame()
|
||||||
|
|
||||||
|
# save frame to PNG image
|
||||||
image = screen.to_pil().convert("RGB")
|
image = screen.to_pil().convert("RGB")
|
||||||
image.save(stream.stdin, "PNG")
|
image.save(stream.stdin, "PNG")
|
||||||
|
|
||||||
|
# sleep until next frame, if necessary
|
||||||
sleep_t = last_frame_t - time.time() + EMULATOR_SPF
|
sleep_t = last_frame_t - time.time() + EMULATOR_SPF
|
||||||
if sleep_t > 0:
|
if sleep_t > 0:
|
||||||
await asyncio.sleep(sleep_t)
|
await asyncio.sleep(sleep_t)
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
"""Start the emulator."""
|
||||||
loop = asyncio.get_event_loop()
|
loop = asyncio.get_event_loop()
|
||||||
|
|
||||||
# setup states in redis
|
# setup states in redis
|
||||||
|
@ -145,3 +163,5 @@ async def main():
|
||||||
|
|
||||||
if __name__ == "__main__":
|
if __name__ == "__main__":
|
||||||
asyncio.run(main())
|
asyncio.run(main())
|
||||||
|
|
||||||
|
# TODO: write code when ctrl+C -> save redis database ?
|
||||||
|
|
|
@ -1,16 +1,18 @@
|
||||||
|
"""Websocket server, responsible for proxying user inputs."""
|
||||||
|
|
||||||
import asyncio
|
import asyncio
|
||||||
import json
|
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from typing import Union
|
|
||||||
|
|
||||||
import redis
|
import redis
|
||||||
import websockets
|
import websockets
|
||||||
|
import websockets.exceptions
|
||||||
|
import websockets.server
|
||||||
|
import websockets.typing
|
||||||
|
|
||||||
from settings import (
|
from settings import (
|
||||||
KEYS_ID,
|
KEYS_ID,
|
||||||
KEYS_RESET,
|
KEYS_RESET,
|
||||||
PASSWORD_ADMIN,
|
|
||||||
REDIS_HOST,
|
REDIS_HOST,
|
||||||
REDIS_PORT,
|
REDIS_PORT,
|
||||||
USER_TIMEOUT,
|
USER_TIMEOUT,
|
||||||
|
@ -27,71 +29,46 @@ r.mset(KEYS_RESET)
|
||||||
USERS: Users = Users()
|
USERS: Users = Users()
|
||||||
|
|
||||||
|
|
||||||
async def parse_message(user: User, message: dict[str, str]) -> None:
|
async def parse_message(user: User, message: websockets.typing.Data) -> None:
|
||||||
"""Parse the user's message.
|
"""Parse the user's message.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
user (User): the sender of the message.
|
user (User): the sender of the message.
|
||||||
message (dict[str, str]): the data received (through the websocket).
|
message (str): the key received (through the websocket).
|
||||||
"""
|
"""
|
||||||
if "auth" in message:
|
if user.last_message + USER_TIMEOUT > time.time():
|
||||||
data = message["auth"]
|
logging.debug(f"dropping action: {message!r} from {user}")
|
||||||
if USERS.admin is None and data == PASSWORD_ADMIN:
|
return None
|
||||||
USERS.admin = user
|
elif message in KEYS_ID:
|
||||||
logging.debug(f"admin authenticated: {user}")
|
r.incr(message)
|
||||||
|
user.last_message = time.time()
|
||||||
response: dict[str, Union[str, list[str]]] = dict()
|
logging.debug(f"received action: {message!r} from {user}")
|
||||||
response["auth"] = "success"
|
else:
|
||||||
states = r.smembers("states")
|
logging.error(f"unsupported action: {message!r} from {user}")
|
||||||
stringlist = [x.decode("utf-8") for x in states]
|
|
||||||
response["states"] = sorted(stringlist)
|
|
||||||
await user.send(json.dumps(response))
|
|
||||||
|
|
||||||
if "admin" in message:
|
|
||||||
if user == USERS.admin:
|
|
||||||
data = message["admin"]
|
|
||||||
if data == "save":
|
|
||||||
r.publish("admin", "save")
|
|
||||||
elif data.startswith("load:"):
|
|
||||||
r.publish("admin", data)
|
|
||||||
else:
|
|
||||||
logging.error(f"unsupported admin action: {data}")
|
|
||||||
else:
|
|
||||||
logging.error(f"user is not admin: {user}")
|
|
||||||
|
|
||||||
if "action" in message:
|
|
||||||
data = message["action"]
|
|
||||||
|
|
||||||
if user.last_message + USER_TIMEOUT > time.time():
|
|
||||||
logging.debug(f"dropping action: {data}")
|
|
||||||
return None
|
|
||||||
elif data in KEYS_ID:
|
|
||||||
r.incr(data)
|
|
||||||
user.last_message = time.time()
|
|
||||||
else:
|
|
||||||
logging.error(f"unsupported action: {data}")
|
|
||||||
|
|
||||||
|
|
||||||
async def handler(websocket, path: str):
|
async def handler(websocket: websockets.server.WebSocketServerProtocol, path: str):
|
||||||
"""Handle the messages sent by a user.
|
"""Handle the messages sent by a user.
|
||||||
|
|
||||||
Args:
|
Args:
|
||||||
websocket: the websocket used by the user.
|
websocket: the websocket used by the user.
|
||||||
path (str): the path used by the websocket. (?)
|
path (str): the path used by the websocket.
|
||||||
"""
|
"""
|
||||||
try:
|
# Register user
|
||||||
# Register user
|
user = User(websocket)
|
||||||
user = User(websocket)
|
USERS.register(user)
|
||||||
USERS.register(user)
|
logging.debug(f"registered user {user}")
|
||||||
# Manage received messages
|
|
||||||
async for json_message in websocket:
|
try: # Manage received messages
|
||||||
message: dict[str, str] = json.loads(json_message)
|
async for message in user.websocket:
|
||||||
await parse_message(user, message)
|
await parse_message(user, message)
|
||||||
|
except websockets.exceptions.ConnectionClosed:
|
||||||
|
logging.error(f"connection with user {user} is already closed")
|
||||||
|
except RuntimeError:
|
||||||
|
logging.error(f"two coroutines called recv() concurrently, user={user}")
|
||||||
finally:
|
finally:
|
||||||
# Unregister user
|
|
||||||
if user == USERS.admin:
|
|
||||||
USERS.admin = None
|
|
||||||
USERS.unregister(user)
|
USERS.unregister(user)
|
||||||
|
logging.debug(f"unregistered user {user}")
|
||||||
|
|
||||||
|
|
||||||
async def main():
|
async def main():
|
||||||
|
|
11
src/utils.py
11
src/utils.py
|
@ -1,15 +1,17 @@
|
||||||
import logging
|
import logging
|
||||||
import time
|
import time
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
from typing import Any, Optional
|
from typing import Any
|
||||||
|
|
||||||
|
import websockets.server
|
||||||
|
import websockets.typing
|
||||||
from mgba._pylib import ffi
|
from mgba._pylib import ffi
|
||||||
|
|
||||||
|
|
||||||
class User:
|
class User:
|
||||||
"""Store infos related to a connected user."""
|
"""Store infos related to a connected user."""
|
||||||
|
|
||||||
websocket: Any
|
websocket: websockets.server.WebSocketServerProtocol
|
||||||
last_message: float
|
last_message: float
|
||||||
|
|
||||||
def __init__(self, websocket: Any) -> None:
|
def __init__(self, websocket: Any) -> None:
|
||||||
|
@ -42,8 +44,6 @@ class User:
|
||||||
class Users(set):
|
class Users(set):
|
||||||
"""Store `User`s connected to the server."""
|
"""Store `User`s connected to the server."""
|
||||||
|
|
||||||
admin: Optional[User] = None
|
|
||||||
|
|
||||||
def register(self, user: User):
|
def register(self, user: User):
|
||||||
"""Register a user in the set.
|
"""Register a user in the set.
|
||||||
|
|
||||||
|
@ -73,7 +73,8 @@ async def save(core):
|
||||||
|
|
||||||
|
|
||||||
async def load(core, filename):
|
async def load(core, filename):
|
||||||
state = ffi.new("unsigned char[397312]") # pulled 397312 from my ass
|
state = ffi.new("unsigned char[397312]") # pulled 397312 straight from my ass
|
||||||
|
# TODO: checker les sources mgba pour savoir d'où sort 397312
|
||||||
with open(f"states/{filename}.state", "rb") as state_file:
|
with open(f"states/{filename}.state", "rb") as state_file:
|
||||||
for i in range(len(state)):
|
for i in range(len(state)):
|
||||||
state[i] = int.from_bytes(state_file.read(4), byteorder="big", signed=False)
|
state[i] = int.from_bytes(state_file.read(4), byteorder="big", signed=False)
|
||||||
|
|
Loading…
Reference in a new issue