Improve the pyproject.toml configuration, add first test, and create more type hints
This commit is contained in:
parent
e1cf1cb194
commit
207178f822
1
.gitignore
vendored
1
.gitignore
vendored
@ -1,4 +1,5 @@
|
||||
__pycache__/
|
||||
.coverage
|
||||
dist/
|
||||
src/pogo_scaled_estimators/_version.py
|
||||
*.sqlite
|
||||
|
@ -27,7 +27,7 @@ classifiers = [
|
||||
]
|
||||
|
||||
[project.urls]
|
||||
"Source" = "https://git.blorp.dev/zo/pogo-scaled-estimators"
|
||||
Source = "https://git.blorp.dev/zo/pogo-scaled-estimators"
|
||||
|
||||
[project.scripts]
|
||||
ase-cli = "pogo_scaled_estimators:main_cli"
|
||||
@ -38,18 +38,46 @@ source = "vcs"
|
||||
[tool.hatch.build.hooks.vcs]
|
||||
version-file = "src/pogo_scaled_estimators/_version.py"
|
||||
|
||||
[tool.hatch.envs.lint]
|
||||
detached = true
|
||||
[tool.hatch.envs.default]
|
||||
dependencies = [
|
||||
"coverage[toml]",
|
||||
"pyright",
|
||||
"pytest",
|
||||
"requests-mock",
|
||||
"ruff",
|
||||
]
|
||||
|
||||
[tool.hatch.envs.lint.scripts]
|
||||
all = ["style", "typing"]
|
||||
format = ["ruff format --fix {args:.}"]
|
||||
style = ["ruff check {args:.}"]
|
||||
typing = ["pyright"]
|
||||
[tool.hatch.envs.default.scripts]
|
||||
test = "pytest {args:tests}"
|
||||
test-cov = "coverage run -m pytest {args:tests}"
|
||||
cov-report = [
|
||||
"- coverage combine",
|
||||
"coverage report",
|
||||
]
|
||||
cov = [
|
||||
"test-cov",
|
||||
"cov-report",
|
||||
]
|
||||
format-check = "ruff format --check --diff {args:.}"
|
||||
format-fix = "ruff format {args:.}"
|
||||
lint-check = "ruff check {args:.}"
|
||||
lint-fix = "ruff check --fix {args:.}"
|
||||
types-check = "pyright"
|
||||
|
||||
[tool.coverage.run]
|
||||
branch = true
|
||||
source_pkgs = ["pogo_scaled_estimators", "tests"]
|
||||
|
||||
[tool.coverage.paths]
|
||||
pogo_scaled_estimators = ["src/pogo_scaled_estimators", "*/pogo_scaled_estimators/src/pogo_scaled_estimators"]
|
||||
tests = ["tests", "*/pogo_scaled_estimators/tests"]
|
||||
|
||||
[tool.coverage.report]
|
||||
exclude_lines = [
|
||||
"no cov",
|
||||
"if __name__ == .__main__.:",
|
||||
"if TYPE_CHECKING:",
|
||||
]
|
||||
|
||||
[tool.ruff]
|
||||
target-version = "py312"
|
||||
@ -112,46 +140,18 @@ ignore = [
|
||||
]
|
||||
|
||||
[tool.pyright]
|
||||
include = ["src/glimmer", "tests"]
|
||||
exclude = ["**/__pycache__"]
|
||||
include = ["src/pogo_scaled_estimators", "tests"]
|
||||
|
||||
reportMissingImports = true
|
||||
reportMissingTypeStubs = false
|
||||
|
||||
pythonVersion = "3.12"
|
||||
pythonPlatform = "Linux"
|
||||
typeCheckingMode = "standard"
|
||||
strictListInference = true
|
||||
strictDictionaryInference = true
|
||||
strictSetInference = true
|
||||
reportAssertAlwaysTrue = "error"
|
||||
reportInvalidStringEscapeSequence = "error"
|
||||
reportSelfClsParameterName = "error"
|
||||
reportConstantRedefinition = "error"
|
||||
reportDeprecated = "error"
|
||||
reportDuplicateImport = "error"
|
||||
reportIncompatibleMethodOverride = "error"
|
||||
reportIncompatibleVariableOverride = "error"
|
||||
reportInconsistentConstructor = "error"
|
||||
reportMatchNotExhaustive = "warning"
|
||||
reportOverlappingOverload = "error"
|
||||
reportMissingSuperCall = "error"
|
||||
reportPrivateUsage = "warning"
|
||||
reportTypeCommentUsage = "error"
|
||||
reportUnnecessaryCast = "error"
|
||||
reportUnnecessaryComparison = "error"
|
||||
reportUnnecessaryContains = "error"
|
||||
reportUnnecessaryIsInstance = "error"
|
||||
reportUnusedClass = "warning"
|
||||
reportUnusedImport = "warning"
|
||||
reportUnusedFunction = "warning"
|
||||
reportUnusedVariable = "warning"
|
||||
reportUntypedBaseClass = "error"
|
||||
reportUntypedClassDecorator = "error"
|
||||
reportUntypedFunctionDecorator = "error"
|
||||
reportUntypedNamedTuple = "error"
|
||||
reportCallInDefaultInitializer = "error"
|
||||
reportImplicitOverride = "error"
|
||||
reportPropertyTypeMismatch = "warning"
|
||||
reportShadowedImports = "warning"
|
||||
reportUninitializedInstanceVariable = "warning"
|
||||
reportUnnecessaryTypeIgnoreComment = "warning"
|
||||
reportUnusedCallResult = "warning"
|
||||
typeCheckingMode = "strict"
|
||||
reportMissingParameterType = "none"
|
||||
reportUnknownArgumentType = "none"
|
||||
reportUnknownLambdaType = "none"
|
||||
reportUnknownMemberType = "none"
|
||||
reportUnknownParameterType = "none"
|
||||
reportUnknownVariableType = "none"
|
||||
reportUnusedFunction = "none"
|
||||
|
@ -6,11 +6,12 @@
|
||||
|
||||
import math
|
||||
from enum import Flag, auto
|
||||
from typing import final
|
||||
from typing import TypeGuard, final, get_args
|
||||
|
||||
from rich.progress import Progress
|
||||
|
||||
from pogo_scaled_estimators.pokebattler_proxy import MovesetResult, PokebattlerProxy, Raid
|
||||
from pogo_scaled_estimators.typing import PokemonType
|
||||
from pogo_scaled_estimators.utilities import format_move_name, format_pokemon_name
|
||||
|
||||
|
||||
@ -22,17 +23,25 @@ class Filter(Flag):
|
||||
DISALLOW_LEGENDARY_POKEMON = auto()
|
||||
|
||||
|
||||
class NotAPokemonTypeError(TypeError):
|
||||
def __init__(self, pokemon_types: list[str]):
|
||||
super().__init__(f"{pokemon_types} contains invalid Pokebattler type name")
|
||||
|
||||
|
||||
@final
|
||||
class Calculator:
|
||||
def __init__(self, attacker_types: list[str], filters: Filter = Filter.NO_FILTER) -> None:
|
||||
self.attacker_types = attacker_types
|
||||
if self._is_list_of_types(attacker_types):
|
||||
self.attacker_types: list[PokemonType] = attacker_types
|
||||
else:
|
||||
raise NotAPokemonTypeError(attacker_types)
|
||||
self.filters = filters
|
||||
self._pokebattler_proxy = PokebattlerProxy()
|
||||
self._progress = Progress()
|
||||
|
||||
def calculate(self, level: int = 40, party: int = 1) -> None:
|
||||
raid_bosses = self._pokebattler_proxy.raid_bosses(self.attacker_types)
|
||||
attackers = {
|
||||
attackers: dict[str, dict[str, list[MovesetResult]]] = {
|
||||
attacker: {"RAID_LEVEL_3": [], "RAID_LEVEL_5": [], "RAID_LEVEL_MEGA": []}
|
||||
for attacker in self._pokebattler_proxy.with_charged_moves(self.attacker_types)
|
||||
if self._allowed_attacker(attacker)
|
||||
@ -66,6 +75,8 @@ class Calculator:
|
||||
for attacker, movesets in results.items()
|
||||
if movesets
|
||||
}
|
||||
if not best_movesets:
|
||||
continue
|
||||
best_estimator = min(best_movesets.values(), key=lambda moveset: moveset.estimator).estimator
|
||||
for attacker, moveset in best_movesets.items():
|
||||
attackers[attacker][simplified_raid_tier].append(moveset.scale(best_estimator))
|
||||
@ -86,6 +97,9 @@ class Calculator:
|
||||
f"[bold]{format_pokemon_name(attacker, attacker_type)}[/bold] ({format_move_name(fast_move, fast_move_type)}/{format_move_name(charged_move, charged_move_type)}): {ase:.2f}"
|
||||
)
|
||||
|
||||
def _is_list_of_types(self, attacker_types: list[str]) -> TypeGuard[list[PokemonType]]:
|
||||
return all(attacker_type in get_args(PokemonType) for attacker_type in attacker_types)
|
||||
|
||||
def _allowed_attacker(self, pokemon_id: str) -> bool:
|
||||
if Filter.DISALLOW_MEGA_POKEMON in self.filters and ("_MEGA" in pokemon_id or "_PRIMAL" in pokemon_id):
|
||||
return False
|
||||
@ -121,17 +135,19 @@ class Calculator:
|
||||
return ("MOVE_NONE", "MOVE_NONE")
|
||||
return min(movesets, key=lambda moveset: self._ase(moveset_results_by_tier, only=moveset))
|
||||
|
||||
def _ase(self, moveset_results_by_tier: dict[str, list[MovesetResult]], only=None) -> float:
|
||||
try:
|
||||
return (
|
||||
0.15 * self._average_estimator(moveset_results_by_tier["RAID_LEVEL_3"], only)
|
||||
+ 0.50 * self._average_estimator(moveset_results_by_tier["RAID_LEVEL_5"], only)
|
||||
+ 0.35 * self._average_estimator(moveset_results_by_tier["RAID_LEVEL_MEGA"], only)
|
||||
)
|
||||
except ZeroDivisionError:
|
||||
return float("inf")
|
||||
def _ase(
|
||||
self, moveset_results_by_tier: dict[str, list[MovesetResult]], only: tuple[str, str] | None = None
|
||||
) -> float:
|
||||
return (
|
||||
0.15 * self._average_estimator(moveset_results_by_tier["RAID_LEVEL_3"], only)
|
||||
+ 0.50 * self._average_estimator(moveset_results_by_tier["RAID_LEVEL_5"], only)
|
||||
+ 0.35 * self._average_estimator(moveset_results_by_tier["RAID_LEVEL_MEGA"], only)
|
||||
)
|
||||
|
||||
def _average_estimator(self, moveset_results: list[MovesetResult], only: tuple[str, str] | None = None) -> float:
|
||||
if only:
|
||||
moveset_results = [m for m in moveset_results if m.fast_move == only[0] and m.charged_move == only[1]]
|
||||
return sum(moveset.estimator for moveset in moveset_results) / len(moveset_results)
|
||||
try:
|
||||
return sum(moveset.estimator for moveset in moveset_results) / len(moveset_results)
|
||||
except ZeroDivisionError:
|
||||
return float("inf")
|
||||
|
@ -5,31 +5,47 @@
|
||||
# https://opensource.org/licenses/MIT.
|
||||
|
||||
import argparse
|
||||
import logging
|
||||
import operator
|
||||
import sys
|
||||
from functools import reduce
|
||||
|
||||
from rich.logging import RichHandler
|
||||
|
||||
from pogo_scaled_estimators._version import version
|
||||
from pogo_scaled_estimators.calculator import Calculator, Filter
|
||||
|
||||
|
||||
def main_cli():
|
||||
parser = argparse.ArgumentParser()
|
||||
_ = parser.add_argument("type", nargs="+", help="an attacker type")
|
||||
_ = parser.add_argument("--level", type=int, default=40)
|
||||
_ = parser.add_argument("--party", type=int, default=1)
|
||||
_ = parser.add_argument("--no-legacy", dest="filters", action="append_const", const=Filter.DISALLOW_LEGACY_MOVES)
|
||||
_ = parser.add_argument("--no-mega", dest="filters", action="append_const", const=Filter.DISALLOW_MEGA_POKEMON)
|
||||
_ = parser.add_argument("--no-shadow", dest="filters", action="append_const", const=Filter.DISALLOW_SHADOW_POKEMON)
|
||||
_ = parser.add_argument(
|
||||
parser.add_argument("type", nargs="+", help="an attacker type")
|
||||
parser.add_argument("--level", type=int, default=40)
|
||||
parser.add_argument("--party", type=int, default=1)
|
||||
parser.add_argument("--no-legacy", dest="filters", action="append_const", const=Filter.DISALLOW_LEGACY_MOVES)
|
||||
parser.add_argument("--no-mega", dest="filters", action="append_const", const=Filter.DISALLOW_MEGA_POKEMON)
|
||||
parser.add_argument("--no-shadow", dest="filters", action="append_const", const=Filter.DISALLOW_SHADOW_POKEMON)
|
||||
parser.add_argument(
|
||||
"--no-legendary", dest="filters", action="append_const", const=Filter.DISALLOW_LEGENDARY_POKEMON
|
||||
)
|
||||
_ = parser.add_argument("--version", action="version", version=version)
|
||||
parser.add_argument(
|
||||
"-v", "--verbose", action="store_const", dest="log_level", const=logging.DEBUG, default=logging.WARNING
|
||||
)
|
||||
parser.add_argument("--version", action="version", version=version)
|
||||
|
||||
args = parser.parse_args()
|
||||
|
||||
logging.basicConfig(level=args.log_level, handlers=[RichHandler()])
|
||||
log = logging.getLogger()
|
||||
requests_log = logging.getLogger("urllib3")
|
||||
requests_log.setLevel(args.log_level)
|
||||
requests_log.propagate = True
|
||||
|
||||
filters = reduce(operator.or_, args.filters or [], Filter.NO_FILTER)
|
||||
calculator = Calculator(args.type, filters)
|
||||
calculator.calculate(level=args.level, party=args.party)
|
||||
try:
|
||||
calculator = Calculator(args.type, filters)
|
||||
calculator.calculate(level=args.level, party=args.party)
|
||||
except Exception:
|
||||
log.exception("Could not calculate ASE.")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
|
@ -7,10 +7,12 @@
|
||||
import urllib.parse
|
||||
from dataclasses import dataclass
|
||||
from functools import cached_property
|
||||
from typing import cast, final
|
||||
from typing import NotRequired, TypedDict, cast, final
|
||||
|
||||
import requests_cache
|
||||
|
||||
from pogo_scaled_estimators.typing import PokemonType
|
||||
|
||||
BASE_URL = "https://fight.pokebattler.com"
|
||||
WEAKNESS = 1.6
|
||||
DOUBLE_WEAKNESS = WEAKNESS * WEAKNESS
|
||||
@ -27,7 +29,7 @@ class Raid:
|
||||
@dataclass
|
||||
class Move:
|
||||
move_id: str
|
||||
typing: str
|
||||
typing: PokemonType
|
||||
|
||||
|
||||
@dataclass
|
||||
@ -36,37 +38,58 @@ class MovesetResult:
|
||||
charged_move: str
|
||||
estimator: float
|
||||
|
||||
def scale(self, factor: float):
|
||||
def scale(self, factor: float) -> "MovesetResult":
|
||||
return MovesetResult(self.fast_move, self.charged_move, self.estimator / factor)
|
||||
|
||||
|
||||
class PokebattlerMove(TypedDict):
|
||||
moveId: str
|
||||
type: PokemonType
|
||||
|
||||
|
||||
class PokebattlerPokemon(TypedDict):
|
||||
pokemonId: str
|
||||
type: PokemonType
|
||||
type2: NotRequired[PokemonType]
|
||||
quickMoves: list[str]
|
||||
cinematicMoves: list[str]
|
||||
movesets: list[dict[str, str]]
|
||||
|
||||
|
||||
class PokebattlerRaid(TypedDict):
|
||||
pokemonId: str
|
||||
|
||||
|
||||
class PokebattlerRaidTierInfo(TypedDict):
|
||||
guessTier: str
|
||||
|
||||
|
||||
class PokebattlerRaidTier(TypedDict):
|
||||
tier: str
|
||||
info: PokebattlerRaidTierInfo
|
||||
raids: list[PokebattlerRaid]
|
||||
|
||||
|
||||
@final
|
||||
class PokebattlerProxy:
|
||||
def __init__(self):
|
||||
self._cached_session = requests_cache.CachedSession("pokebatter_cache", cache_control=True)
|
||||
self._pokemon: dict | None = None
|
||||
self._raids: dict | None = None
|
||||
self._resists: dict | None = None
|
||||
|
||||
@property
|
||||
def cached_session(self):
|
||||
return self._cached_session
|
||||
def __init__(self, log_level="INFO"):
|
||||
self._cached_session = requests_cache.CachedSession("pokebatter_cache", cache_control=True, use_cache_dir=True)
|
||||
|
||||
@cached_property
|
||||
def moves(self) -> dict:
|
||||
return self.cached_session.get(f"{BASE_URL}/moves").json()["move"]
|
||||
def moves(self) -> list[PokebattlerMove]:
|
||||
return self._cached_session.get(f"{BASE_URL}/moves").json()["move"]
|
||||
|
||||
@cached_property
|
||||
def pokemon(self) -> dict:
|
||||
return self.cached_session.get(f"{BASE_URL}/pokemon").json()["pokemon"]
|
||||
def pokemon(self) -> list[PokebattlerPokemon]:
|
||||
return self._cached_session.get(f"{BASE_URL}/pokemon").json()["pokemon"]
|
||||
|
||||
@cached_property
|
||||
def raids(self) -> dict:
|
||||
return self.cached_session.get(f"{BASE_URL}/raids").json()
|
||||
def raids(self) -> list[PokebattlerRaidTier]:
|
||||
return self._cached_session.get(f"{BASE_URL}/raids").json()["tiers"]
|
||||
|
||||
@cached_property
|
||||
def resists(self) -> dict:
|
||||
return self.cached_session.get(f"{BASE_URL}/resists").json()
|
||||
def resists(self) -> dict[str, list[float]]:
|
||||
return self._cached_session.get(f"{BASE_URL}/resists").json()
|
||||
|
||||
def simulate(self, raid: Raid) -> dict[str, list[MovesetResult]]:
|
||||
query_string = {
|
||||
@ -85,7 +108,8 @@ class PokebattlerProxy:
|
||||
url = f"{BASE_URL}/raids/defenders/{raid.defender}/levels/{raid.tier}/attackers/levels/{raid.level}/strategies/CINEMATIC_ATTACK_WHEN_POSSIBLE/DEFENSE_RANDOM_MC?{urllib.parse.urlencode(query_string, doseq=True)}"
|
||||
response = self._cached_session.get(url)
|
||||
results: dict[str, list[MovesetResult]] = {}
|
||||
for attacker in response.json()["attackers"][0]["randomMove"]["defenders"]:
|
||||
response_json = response.json()
|
||||
for attacker in response_json["attackers"][0]["randomMove"]["defenders"]:
|
||||
results[attacker["pokemonId"]] = [
|
||||
MovesetResult(
|
||||
attacker_moves["move1"], attacker_moves["move2"], cast(float, attacker_moves["result"]["estimator"])
|
||||
@ -94,9 +118,9 @@ class PokebattlerProxy:
|
||||
]
|
||||
return results
|
||||
|
||||
def raid_bosses(self, attacker_types: list[str]) -> dict:
|
||||
raid_tiers = []
|
||||
raid_bosses = {}
|
||||
def raid_bosses(self, attacker_types: list[PokemonType]) -> dict[str, list[str]]:
|
||||
raid_tiers: list[str] = []
|
||||
raid_bosses: dict[str, list[str]] = {}
|
||||
|
||||
for raid_level in ["3", "5", "MEGA", "MEGA_5", "ULTRA_BEAST"]:
|
||||
tier = f"RAID_LEVEL_{raid_level}"
|
||||
@ -109,21 +133,12 @@ class PokebattlerProxy:
|
||||
)
|
||||
raid_bosses[tier] = []
|
||||
|
||||
for tier in filter(lambda tier: tier["tier"] in raid_tiers, self.raids["tiers"]):
|
||||
for boss in (raid["pokemon"] for raid in tier["raids"]):
|
||||
for tier in filter(lambda tier: tier["tier"] in raid_tiers, self.raids):
|
||||
for boss in (raid["pokemonId"] for raid in tier["raids"]):
|
||||
if boss.endswith("_FORM"):
|
||||
continue
|
||||
boss_pokemon = next(filter(lambda mon: mon["pokemonId"] == boss, self.pokemon))
|
||||
if ("candyToEvolve" in boss_pokemon or boss in ["SEADRA", "SEALEO"]) and boss not in [
|
||||
"KELDEO",
|
||||
"LUMINEON",
|
||||
"MANAPHY",
|
||||
"PHIONE",
|
||||
"STUNFISK",
|
||||
"TERRAKION",
|
||||
]:
|
||||
continue
|
||||
boss_types = (
|
||||
boss_pokemon: PokebattlerPokemon = next(filter(lambda mon: mon["pokemonId"] == boss, self.pokemon))
|
||||
boss_types: tuple[PokemonType, PokemonType] = (
|
||||
boss_pokemon["type"],
|
||||
boss_pokemon.get("type2", "POKEMON_TYPE_NONE"),
|
||||
)
|
||||
@ -132,7 +147,7 @@ class PokebattlerProxy:
|
||||
|
||||
return raid_bosses
|
||||
|
||||
def _is_weak(self, attacker_type: str, defender_types: tuple[str, str]) -> bool:
|
||||
def _is_weak(self, attacker_type: PokemonType, defender_types: tuple[PokemonType, PokemonType]) -> bool:
|
||||
pokemon_types = list(self.resists.keys())
|
||||
defender_type_indices = (
|
||||
pokemon_types.index(defender_types[0]),
|
||||
@ -156,8 +171,8 @@ class PokebattlerProxy:
|
||||
|
||||
return False
|
||||
|
||||
def with_charged_moves(self, attacker_types: list[str]) -> list[str]:
|
||||
charged_moves = [
|
||||
def with_charged_moves(self, attacker_types: list[PokemonType]) -> list[str]:
|
||||
charged_moves: list[str] = [
|
||||
move["moveId"]
|
||||
for move in self.moves
|
||||
if "moveId" in move and "type" in move and move["type"] in attacker_types
|
||||
@ -168,10 +183,10 @@ class PokebattlerProxy:
|
||||
if any(moveset["cinematicMove"] in charged_moves for moveset in mon["movesets"])
|
||||
]
|
||||
|
||||
def find_pokemon(self, name: str) -> dict:
|
||||
def find_pokemon(self, name: str) -> PokebattlerPokemon:
|
||||
return next(filter(lambda mon: mon["pokemonId"] == name, self.pokemon))
|
||||
|
||||
def pokemon_type(self, name: str) -> str:
|
||||
def pokemon_type(self, name: str) -> PokemonType:
|
||||
return self.find_pokemon(name)["type"]
|
||||
|
||||
def find_move(self, move_id: str) -> Move:
|
||||
|
29
src/pogo_scaled_estimators/typing.py
Normal file
29
src/pogo_scaled_estimators/typing.py
Normal file
@ -0,0 +1,29 @@
|
||||
# Copyright 2024 Zoé Cassiopée Gauthier.
|
||||
#
|
||||
# Use of this source code is governed by an MIT-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://opensource.org/licenses/MIT.
|
||||
|
||||
from typing import Literal
|
||||
|
||||
PokemonType = Literal[
|
||||
"POKEMON_TYPE_NONE",
|
||||
"POKEMON_TYPE_BUG",
|
||||
"POKEMON_TYPE_DARK",
|
||||
"POKEMON_TYPE_DRAGON",
|
||||
"POKEMON_TYPE_ELECTRIC",
|
||||
"POKEMON_TYPE_FAIRY",
|
||||
"POKEMON_TYPE_FIGHTING",
|
||||
"POKEMON_TYPE_FIRE",
|
||||
"POKEMON_TYPE_FLYING",
|
||||
"POKEMON_TYPE_GHOST",
|
||||
"POKEMON_TYPE_GRASS",
|
||||
"POKEMON_TYPE_GROUND",
|
||||
"POKEMON_TYPE_ICE",
|
||||
"POKEMON_TYPE_NORMAL",
|
||||
"POKEMON_TYPE_POISON",
|
||||
"POKEMON_TYPE_PSYCHIC",
|
||||
"POKEMON_TYPE_ROCK",
|
||||
"POKEMON_TYPE_STEEL",
|
||||
"POKEMON_TYPE_WATER",
|
||||
]
|
@ -4,7 +4,11 @@
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://opensource.org/licenses/MIT.
|
||||
|
||||
POKEMON_TYPE_COLORS = {
|
||||
from typing import Final
|
||||
|
||||
from pogo_scaled_estimators.typing import PokemonType
|
||||
|
||||
POKEMON_TYPE_COLORS: Final[dict[PokemonType, str]] = {
|
||||
"POKEMON_TYPE_BUG": "green_yellow",
|
||||
"POKEMON_TYPE_DARK": "bright_black",
|
||||
"POKEMON_TYPE_DRAGON": "dodger_blue2",
|
||||
@ -27,7 +31,7 @@ POKEMON_TYPE_COLORS = {
|
||||
MINIMUM_SPECIAL_NAME_PARTS = 2
|
||||
|
||||
|
||||
def format_pokemon_name(name: str, pokemon_type: str | None = None):
|
||||
def format_pokemon_name(name: str, pokemon_type: PokemonType | None = None):
|
||||
parts = [part.capitalize() for part in name.split("_")]
|
||||
if parts[-1] == "Mega" or parts[-1] == "Primal":
|
||||
parts = [parts[-1]] + parts[:-1]
|
||||
@ -41,7 +45,7 @@ def format_pokemon_name(name: str, pokemon_type: str | None = None):
|
||||
return formatted_name
|
||||
|
||||
|
||||
def format_move_name(name, move_type: str | None = None):
|
||||
def format_move_name(name: str, move_type: PokemonType | None = None):
|
||||
parts = [part.capitalize() for part in name.split("_")]
|
||||
if parts[-1] == "Fast":
|
||||
parts = parts[:-1]
|
||||
|
5
tests/__init__.py
Normal file
5
tests/__init__.py
Normal file
@ -0,0 +1,5 @@
|
||||
# Copyright 2024 Zoé Cassiopée Gauthier.
|
||||
#
|
||||
# Use of this source code is governed by an MIT-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://opensource.org/licenses/MIT.
|
55
tests/test_pokebattler_proxy.py
Normal file
55
tests/test_pokebattler_proxy.py
Normal file
@ -0,0 +1,55 @@
|
||||
# Copyright 2024 Zoé Cassiopée Gauthier.
|
||||
#
|
||||
# Use of this source code is governed by an MIT-style
|
||||
# license that can be found in the LICENSE file or at
|
||||
# https://opensource.org/licenses/MIT.
|
||||
|
||||
import unittest.mock
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from pogo_scaled_estimators.pokebattler_proxy import MovesetResult, PokebattlerProxy, Raid
|
||||
|
||||
|
||||
class MockSession(requests.Session):
|
||||
def __init__(self, *args, **kwargs):
|
||||
kwargs.pop("cache_control", None)
|
||||
kwargs.pop("use_cache_dir", None)
|
||||
super().__init__(*[], **kwargs)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def _disable_requests_cache():
|
||||
"""Replace CachedSession with a regular Session for all test functions"""
|
||||
with unittest.mock.patch("requests_cache.CachedSession", MockSession):
|
||||
yield
|
||||
|
||||
|
||||
def test_simulate(requests_mock):
|
||||
requests_mock.get(
|
||||
"https://fight.pokebattler.com/raids/defenders/MEWTWO_SHADOW_FORM/levels/RAID_LEVEL_5_SHADOW/attackers/levels/40/strategies/CINEMATIC_ATTACK_WHEN_POSSIBLE/DEFENSE_RANDOM_MC",
|
||||
json={
|
||||
"attackers": [
|
||||
{
|
||||
"randomMove": {
|
||||
"move1": "RANDOM",
|
||||
"move2": "RANDOM",
|
||||
"defenders": [
|
||||
{
|
||||
"pokemonId": "BIDOOF",
|
||||
"byMove": [
|
||||
{"move1": "TACKLE_FAST", "move2": "HYPER_FANG", "result": {"estimator": 5.4321}},
|
||||
],
|
||||
"cp": 721,
|
||||
}
|
||||
],
|
||||
}
|
||||
}
|
||||
]
|
||||
},
|
||||
)
|
||||
|
||||
pokebattler_proxy = PokebattlerProxy()
|
||||
raid = Raid("RAID_LEVEL_5_SHADOW", "MEWTWO_SHADOW_FORM")
|
||||
results = pokebattler_proxy.simulate(raid)
|
||||
assert results == {"BIDOOF": [MovesetResult("TACKLE_FAST", "HYPER_FANG", 5.4321)]}
|
Loading…
Reference in New Issue
Block a user