1
0
Fork 0
mirror of https://github.com/cosmo-sims/cosmICweb-music.git synced 2024-09-19 16:53:43 +02:00

more typing

This commit is contained in:
Lukas Winkler 2024-04-20 23:20:54 +02:00
parent 41f2d058fb
commit d778a14651
Signed by: lukas
GPG key ID: 54DE4D798D244853
2 changed files with 23 additions and 67 deletions

View file

@ -1,3 +1,5 @@
from __future__ import annotations
import os import os
import sys import sys
import tempfile import tempfile
@ -26,7 +28,7 @@ EDITOR = os.environ.get("EDITOR", "vim")
EDITOR_IS_VIM = EDITOR in {"vim", "nvim"} EDITOR_IS_VIM = EDITOR in {"vim", "nvim"}
def query_yes_no(question, default="yes"): def query_yes_no(question: str, default="yes") -> bool:
"""Ask a yes/no question via raw_input() and return their answer. """Ask a yes/no question via raw_input() and return their answer.
"question" is a string that is presented to the user. "question" is a string that is presented to the user.
@ -58,7 +60,7 @@ def query_yes_no(question, default="yes"):
# Routines # Routines
def fetch_ellipsoids(url, api_token, attempts): def fetch_ellipsoids(url: str, api_token: str, attempts: int) -> list[Ellipsoid]:
for i in range(attempts): for i in range(attempts):
try: try:
r = requests.get(url, headers={"Authorization": "Token " + api_token}) r = requests.get(url, headers={"Authorization": "Token " + api_token})
@ -79,19 +81,21 @@ def fetch_ellipsoids(url, api_token, attempts):
for e in content for e in content
] ]
logging.error("Unable to download ellipsoids from {}".format(url)) logging.error("Unable to download ellipsoids from {}".format(url))
return None return []
def fetch_ellipsoid(url, api_token, traceback_radius, attempts=3): def fetch_ellipsoid(
url: str, api_token: str, traceback_radius, attempts: int = 3
) -> Ellipsoid | None:
ellipsoids = fetch_ellipsoids(url, api_token, attempts) ellipsoids = fetch_ellipsoids(url, api_token, attempts)
if ellipsoids is not None: if ellipsoids:
return next( return next(
(e for e in ellipsoids if e.traceback_radius == traceback_radius), None (e for e in ellipsoids if e.traceback_radius == traceback_radius), None
) )
return None return None
def fetch_downloadstore(cosmicweb_url, target): def fetch_downloadstore(cosmicweb_url: str, target: str) -> DownloadConfig:
try: try:
r = requests.get(cosmicweb_url + "/api/music/store/" + target) r = requests.get(cosmicweb_url + "/api/music/store/" + target)
# This will raise an error if not successful # This will raise an error if not successful
@ -122,7 +126,9 @@ def fetch_downloadstore(cosmicweb_url, target):
) )
def fetch_publication(cosmicweb_url, publication_name, traceback_radius): def fetch_publication(
cosmicweb_url: str, publication_name: str, traceback_radius
) -> DownloadConfig:
try: try:
r = requests.get(cosmicweb_url + "/api/publications/" + publication_name) r = requests.get(cosmicweb_url + "/api/publications/" + publication_name)
# This will raise an error if not successful # This will raise an error if not successful
@ -155,7 +161,7 @@ def fetch_publication(cosmicweb_url, publication_name, traceback_radius):
) )
def edit_template(template): def edit_template(template: str) -> str:
with tempfile.NamedTemporaryFile(suffix=".tmp.conf", mode="r+") as tf: with tempfile.NamedTemporaryFile(suffix=".tmp.conf", mode="r+") as tf:
tf.write(template) tf.write(template)
tf.flush() tf.flush()
@ -169,7 +175,7 @@ def edit_template(template):
return template return template
def apply_config_parameter(config: str, parameters: dict[str, Any]): def apply_config_parameter(config: str, parameters: dict[str, Any]) -> str:
new_lines = [] new_lines = []
for line in config.split("\n"): for line in config.split("\n"):
param = line.split("=")[0].strip() param = line.split("=")[0].strip()
@ -179,7 +185,7 @@ def apply_config_parameter(config: str, parameters: dict[str, Any]):
return "\n".join(new_lines) return "\n".join(new_lines)
def music_config_to_template(config: DownloadConfig): def music_config_to_template(config: DownloadConfig) -> str:
music_config = config.MUSIC music_config = config.MUSIC
settings = config.settings settings = config.settings
# TODO: apply output configuration # TODO: apply output configuration
@ -208,7 +214,7 @@ def compose_template(
config: DownloadConfig, config: DownloadConfig,
halo_name: str, halo_name: str,
halo_id: int, halo_id: int,
): ) -> str:
# TODO: add ellipsoid header (rtb, halo_name, etc) # TODO: add ellipsoid header (rtb, halo_name, etc)
shape_0 = ", ".join(str(e) for e in ellipsoid.shape[0]) shape_0 = ", ".join(str(e) for e in ellipsoid.shape[0])
shape_1 = ", ".join(str(e) for e in ellipsoid.shape[1]) shape_1 = ", ".join(str(e) for e in ellipsoid.shape[1])
@ -243,11 +249,11 @@ def write_music_file(output_file, music_config) -> None:
f.write(music_config) f.write(music_config)
def call_music(): def call_music() -> None:
pass pass
def process_config(config: DownloadConfig, args: Args): def process_config(config: DownloadConfig, args: Args) -> None:
ellipsoids = [] ellipsoids = []
for halo_name, url in zip(config.halo_names, config.halo_urls): for halo_name, url in zip(config.halo_names, config.halo_urls):
logging.info("Fetching ellipsoids from halo " + halo_name) logging.info("Fetching ellipsoids from halo " + halo_name)
@ -299,7 +305,7 @@ def process_config(config: DownloadConfig, args: Args):
# TODO: Execute MUSIC? # TODO: Execute MUSIC?
def downloadstore_mode(args: Args, target: str): def downloadstore_mode(args: Args, target: str) -> None:
logging.info("Fetching download configuration from the cosmICweb server") logging.info("Fetching download configuration from the cosmICweb server")
config = fetch_downloadstore(args.url, target) config = fetch_downloadstore(args.url, target)
if args.output_path == "./": if args.output_path == "./":
@ -309,7 +315,7 @@ def downloadstore_mode(args: Args, target: str):
process_config(config, args) process_config(config, args)
def publication_mode(args: Args, publication_name: str, traceback_radius: int): def publication_mode(args: Args, publication_name: str, traceback_radius: int) -> None:
logging.info( logging.info(
"Fetching publication " + publication_name + " from the cosmICweb server" "Fetching publication " + publication_name + " from the cosmICweb server"
) )
@ -320,7 +326,7 @@ def publication_mode(args: Args, publication_name: str, traceback_radius: int):
process_config(config, args) process_config(config, args)
def dir_path(p): def dir_path(p: str) -> str:
if os.path.isdir(p): if os.path.isdir(p):
return p return p
else: else:
@ -373,54 +379,3 @@ def publication(ctx, publication_name, traceback_radius):
if __name__ == "__main__": if __name__ == "__main__":
cli() cli()
# if __name__ == "__main__":
# parser = argparse.ArgumentParser()
# parser.add_argument(
# "--url",
# dest="cosmicweb_url",
# default=DEFAULT_URL,
# help="overwrite URL of the cosmicweb server",
# )
# parser.add_argument(
# "--output-path",
# type=dir_path,
# default="./",
# help="Download target for IC files. If downloading publication, will create a subfolder with the "
# "name of the publication",
# )
# parser.add_argument(
# "--common-directory", dest="create_subdirs", action="store_false"
# )
# parser.add_argument(
# "--attempts",
# type=int,
# default=3,
# help="number of attempts to download ellipsoids",
# )
# parser.add_argument("--verbose", action="store_true")
# subparsers = parser.add_subparsers(dest="mode")
# # Downloading from publications
# publication_parser = subparsers.add_parser(
# "publication", help="download publications"
# )
# publication_parser.add_argument("publication_name", help="name of the publication")
# publication_parser.add_argument(
# "--traceback_radius", type=int, choices=[1, 2, 4, 10], default=2, help=""
# )
# # Downloading from download object
# download_parser = subparsers.add_parser("get")
# download_parser.add_argument("target")
# args = parser.parse_args()
# if args.verbose:
# logger.setLevel("DEBUG")
# if args.mode == "get":
# downloadstore_mode(args)
# elif args.mode == "publication":
# publication_mode(args)
# else:
# raise NotImplementedError("unknown subparser")

View file

@ -1,5 +1,6 @@
from __future__ import annotations from __future__ import annotations
from datetime import datetime
from typing import NamedTuple, Any, List, Dict, TypedDict from typing import NamedTuple, Any, List, Dict, TypedDict