better environment observation

main
Harald Holtmann 2025-09-06 20:23:15 +02:00
parent 341cd92a33
commit 5d2d39586e
2 changed files with 91 additions and 52 deletions

@ -1,6 +1,7 @@
import json import json
import logging import os
import os.path import os.path
import random
from argparse import Namespace from argparse import Namespace
from typing import List from typing import List
@ -29,7 +30,32 @@ def select(problem: str):
return response.json()["problemName"] return response.json()["problemName"]
explore_cache = {"seed": random.getrandbits(30)}
try:
with open("explore_cache.json") as h:
explore_cache = json.loads(h.read())
except OSError:
pass
random.seed(explore_cache["seed"])
def write_explore_cache():
with open("explore_cache.json", "w") as h:
h.write(json.dumps(explore_cache))
def clean_explore_cache():
try:
os.remove("explore_cache.json")
except OSError:
pass
def explore(plans: List[str]): def explore(plans: List[str]):
cache_key = ",".join(plans)
if cache_key in explore_cache:
return explore_cache[cache_key]
response = requests.post( response = requests.post(
config.contest_url + "/explore", config.contest_url + "/explore",
json={"plans": plans, "id": config.id}, json={"plans": plans, "id": config.id},
@ -37,7 +63,9 @@ def explore(plans: List[str]):
if not response.ok: if not response.ok:
raise APIError(f"{response.status_code}: {response.text}") raise APIError(f"{response.status_code}: {response.text}")
return response.json() data = response.json()
explore_cache[cache_key] = data
return data
def guess(layout): def guess(layout):

@ -85,25 +85,6 @@ class Explore:
return path return path
def save(self):
return {
"problem": self.problem,
"room": dict(self.rooms),
"room_ids": dict(self.room_ids),
"neighbours": dict(self.neighbors),
"unification_id": dict(self.unification_id),
}
@classmethod
def load(cls, obj):
new = cls(obj["problem"])
new.problem = obj["problem"]
new.rooms.update(obj["room"])
new.room_ids.update(obj["room_ids"])
new.neighbors.update(obj["neighbours"])
new.unification_id.update(obj["unification_id"])
return new
def explore(self, path=Path(), probes=None): def explore(self, path=Path(), probes=None):
probes = probes or self.probes probes = probes or self.probes
@ -120,13 +101,16 @@ class Explore:
def _add_room(self, path, results): def _add_room(self, path, results):
label = results[0][0] label = results[0][0]
assert all(result[0] == label for result in results)
probe_ids = [r for res in (rs[1:] for rs in results) for r in res] probe_ids = [r for res in (rs[1:] for rs in results) for r in res]
room_id = str(label) + "".join(str(p) for p in probe_ids) room_id = str(label) + "".join(str(p) for p in probe_ids)
# print("add room", path, results, room_id) # print("add room", path, results, room_id)
if path in self.rooms: if path in self.rooms:
rid = self.rooms[path] rid = self.rooms[path]
assert rid == room_id, f"expected match room at {path}: {rid} != {room_id}" if rid != room_id:
raise ExploreError(f"expected match room at {path}: {rid} != {room_id}")
else: else:
self.rooms[path] = room_id self.rooms[path] = room_id
self.room_ids[room_id].add(path) self.room_ids[room_id].add(path)
@ -147,14 +131,16 @@ class Explore:
pl, d = path.last() pl, d = path.last()
pl = self._path(pl) pl = self._path(pl)
if pl in self.rooms: if pl in self.rooms:
print("penult", path, pl, d, pl in self.rooms) # print("penult", path, pl, d, pl in self.rooms)
self.dump() # self.dump()
p, rid = self.neighbors[pl][d] p, rid = self.neighbors[pl][d]
assert rid is None or rid == room_id, f"penultimate {path} {pl}: {rid} != {room_id}" assert (
rid is None or rid == room_id
), f"penultimate {path} {pl}: {rid} != {room_id}"
self.neighbors[pl][d] = (p, room_id) self.neighbors[pl][d] = (p, room_id)
self.dump() # self.dump()
def update_path(self, path, door, result0, result1): def update_path(self, path, door, result0, result1):
path = self._path(path) path = self._path(path)
@ -186,21 +172,17 @@ class Explore:
try to unify rooms at paths p1, p2 try to unify rooms at paths p1, p2
return unified rooms return unified rooms
""" """
path1 = self._path(path1)
path2 = self._path(path2)
print("unify", path1, path2) print("unify", path1, path2)
if path1 == path2: if path1 == path2:
return return
if path1 not in self.rooms: assert path1 in self.rooms, f"room '{path1}' not explored"
raise ExploreError(f"room '{path1}' not explored") assert path2 in self.rooms, f"room '{path2}' not explored"
if path2 not in self.rooms:
raise ExploreError(f"room '{path2}' not explored")
if self.rooms[path1] != self.rooms[path2]: if self.rooms[path1] != self.rooms[path2]:
raise ExploreError(f"ids of '{path1}'({self.rooms[path1]}) and '{path2}'({self.rooms[path2]}) do not match") raise ExploreError(
f"ids of '{path1}'({self.rooms[path1]}) and '{path2}'({self.rooms[path2]}) do not match"
)
path = min(path1, path2) path = min(path1, path2)
pmerge = max(path1, path2) pmerge = max(path1, path2)
@ -209,10 +191,17 @@ class Explore:
self.unification_id[pmerge] = path self.unification_id[pmerge] = path
merged_neighbors = [] merged_neighbors = []
for n, ((p, rid), (pm, rmid)) in enumerate(zip(self.neighbors[path], self.neighbors[pmerge])): for n, ((p, rid), (pm, rmid)) in enumerate(
zip(self.neighbors[path], self.neighbors[pmerge])
):
if rid and rmid and rid != rmid: if rid and rmid and rid != rmid:
raise ExploreError(f"neighbor {n} of '{path}'({rid}) and '{pmerge}'({rmid}) do not match") raise ExploreError(
merged_neighbors.append((self._path(p), rid or rmid)) f"neighbor {n} of '{path}'({rid}) and '{pmerge}'({rmid}) do not match"
)
if rmid:
merged_neighbors.append((self._path(pm), rmid))
else:
merged_neighbors.append((self._path(p), rid))
# fix rooms # fix rooms
del self.rooms[pmerge] del self.rooms[pmerge]
@ -239,7 +228,9 @@ class Explore:
new.append((np_, rid)) new.append((np_, rid))
if rid: if rid:
try: try:
assert np_ in self.rooms, f"unify: path {np} {np_} of {(p, ns)} not in rooms" assert (
np_ in self.rooms
), f"unify: path {np} {np_} of {(p, ns)} not in rooms"
except AssertionError as exc: except AssertionError as exc:
self.dump() self.dump()
raise exc raise exc
@ -265,7 +256,9 @@ class Explore:
yield d, path yield d, path
def is_explored(self): def is_explored(self):
return next(self.unexplored(), None) is None and all(len(p) == 1 for p in self.room_ids.values()) return next(self.unexplored(), None) is None and all(
len(p) == 1 for p in self.room_ids.values()
)
def guess(self): def guess(self):
ids = {} ids = {}
@ -278,7 +271,18 @@ class Explore:
src_id = self.rooms[path] src_id = self.rooms[path]
for src_door, (trg_path, trg_id) in enumerate(ns): for src_door, (trg_path, trg_id) in enumerate(ns):
src = (src_id, src_door) src = (src_id, src_door)
trg_door = next(j for j, (p, rid) in enumerate(self.neighbors[trg_path]) if rid == src_id) trg_door = next(
(
j
for j, (p, rid) in enumerate(self.neighbors[trg_path])
if rid == src_id
),
None,
)
if trg_door is None:
raise ExploreError(
f"backlink not found: {(src, trg_path, self.neighbors[trg_path])}"
)
trg = (trg_id, trg_door) trg = (trg_id, trg_door)
if (src, trg) in connected or (trg, src) in connected: if (src, trg) in connected or (trg, src) in connected:
@ -301,8 +305,9 @@ class Explore:
return api.guess(layout) return api.guess(layout)
def room_solve(problem): def room_solve(problem, nrooms, plen):
ex = Explore(problem, DOORS) ex = Explore(problem, [d+d+d+d+d+d for d in DOORS[:plen]])
print(ex.probes)
api.select(ex.problem) api.select(ex.problem)
res = ex.explore(Path()) res = ex.explore(Path())
@ -310,16 +315,16 @@ def room_solve(problem):
ex.dump() ex.dump()
while True: while True:
unexplored = next(ex.unexplored(), None) door, unexplored = next(ex.unexplored(), (None, None))
if not unexplored: if unexplored is None:
break break
print("explore", unexplored) print("explore", door, unexplored)
res = ex.explore(unexplored) path = unexplored + Path([door])
ex.update(unexplored, res) res = ex.explore(path)
ex.update(path, res)
ex.dump() ex.dump()
print("unify")
ex.unify_all() ex.unify_all()
ex.dump() ex.dump()
@ -384,8 +389,6 @@ def path_solve(problem, nrooms, plen):
if __name__ == "__main__": if __name__ == "__main__":
# with open("test.rooms") as h:
# obj = ast.literal_eval(h.read())
problem = sys.argv[1] problem = sys.argv[1]
with open(os.path.join("..", "problems.json")) as h: with open(os.path.join("..", "problems.json")) as h:
@ -396,4 +399,12 @@ if __name__ == "__main__":
if problem not in problems: if problem not in problems:
raise ExploreError(f"unknown problem {problem}") raise ExploreError(f"unknown problem {problem}")
path_solve(problem, problems[problem]["size"], int(sys.argv[2])) try:
room_solve(problem, problems[problem]["size"], int(sys.argv[2]))
api.clean_explore_cache()
# except ExploreError as exc:
# api.clean_explore_cache()
# raise exc
except Exception as exc:
api.write_explore_cache()
raise exc