Skip to content

Cache fails to display in a notebook when accessed via the path syntax, and printing it prints the full cache #509

@Butanium

Description

@Butanium
from nnsight import LanguageModel
gpt2 = LanguageModel("gpt2", rename={".transformer.h": "layers"})

with gpt2.trace("Hello") as tracer:
    cache = tracer.cache(modules=[gpt2.layers[0]]).save()

print(cache.keys())
print(f"type: {type(cache['model.transformer.h.0'])}")
print(f"type: {type(cache.model.transformer.h[0])}")
display(f"value: {cache['model.transformer.h.0']}")
display(f"value: {cache.model.transformer.h[0]}")
display(cache["model.transformer.h.0"])
display(cache.model.transformer.h[0])  # fails

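Error raised by the last call, `display(cache.model.transformer.h[0])`. IPython's dict pretty-printer iterates the object's keys, which appear to already be full paths such as `model.transformer.h.0`, and looks each one up with `obj[key]`. `CacheDict.__getitem__` then prefixes the key with its own path again (`self._path + "." + name`), so the lookup uses the doubled key `model.transformer.h.0.model.transformer.h.0`: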

---------------------------------------------------------------------------
KeyError                                  Traceback (most recent call last)
File ~/projects/nnterp/.venv/lib/python3.10/site-packages/IPython/core/formatters.py:770, in PlainTextFormatter.__call__(self, obj)
    763 stream = StringIO()
    764 printer = pretty.RepresentationPrinter(stream, self.verbose,
    765     self.max_width, self.newline,
    766     max_seq_length=self.max_seq_length,
    767     singleton_pprinters=self.singleton_printers,
    768     type_pprinters=self.type_printers,
    769     deferred_pprinters=self.deferred_printers)
--> 770 printer.pretty(obj)
    771 printer.flush()
    772 return stream.getvalue()

File ~/projects/nnterp/.venv/lib/python3.10/site-packages/IPython/lib/pretty.py:394, in RepresentationPrinter.pretty(self, obj)
    391 for cls in _get_mro(obj_class):
    392     if cls in self.type_pprinters:
    393         # printer registered in self.type_pprinters
--> 394         return self.type_pprinters[cls](obj, self, cycle)
    395     else:
    396         # deferred printer
    397         printer = self._in_deferred_types(cls)

File ~/projects/nnterp/.venv/lib/python3.10/site-packages/IPython/lib/pretty.py:701, in _dict_pprinter_factory.<locals>.inner(obj, p, cycle)
    699     p.pretty(key)
    700     p.text(': ')
--> 701     p.pretty(obj[key])
    702 p.end_group(step, end)

File ~/projects/nnterp/.venv/lib/python3.10/site-packages/nnsight/intervention/tracing/tracer.py:136, in Cache.CacheDict.__getitem__(self, key)
    133     name = self._alias_paths.get(name, name)
    135     path = self._path + "." + name if self._path != "" else name
--> 136     return dict.__getitem__(self, path)
    138 if isinstance(name, int):
    139     path = self._path + "." + f"{name}"

KeyError: 'model.transformer.h.0.model.transformer.h.0'

Metadata

Metadata

Assignees

No one assigned

    Labels

    bug (Something isn't working), cache

    Type

    No type

    Projects

    No projects

    Milestone

    No milestone

    Relationships

    None yet

    Development

    No branches or pull requests

    Issue actions