Reject raw callables in config during deserialization #22736

Open
Saumay wants to merge 1 commit into keras-team:master from Saumay:fix-deserialize-raw-callables

Conversation

@Saumay Saumay commented Apr 21, 2026

Description

Fixes #22705.

deserialize_keras_object silently accepted raw Python callables embedded in a config dict (the issue's repro nests one under config["config"]). The callable flowed through cls.from_config(inner_config), got stored on the reconstructed object, and was invoked later (e.g. during build()), with no validation or warning.

This PR walks the outer config recursively and rejects raw callables (types.FunctionType, BuiltinFunctionType, MethodType, BuiltinMethodType, functools.partial, functools.partialmethod) when safe_mode=True (the default). The walk runs before cls.from_config, build_from_config, and compile_from_config can see the config, so callables in config["config"], config["build_config"], config["compile_config"], or any other outer-level key are all covered.

safe_mode=False preserves the previous silent-accept behavior as an explicit opt-out, mirroring the existing opt-out for lambda deserialization.

Class objects and class instances themselves are not rejected (they have different semantics from raw functions). Bound methods (MethodType) are rejected since they are effectively ad-hoc function references.
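
The distinction above can be illustrated with a minimal sketch. The names `RAW_CALLABLE_TYPES`, `is_raw_callable`, and `Scaler` are hypothetical stand-ins; the PR's own `_RAW_CALLABLE_TYPES` tuple is assumed to list the types named in this description, but its definition is not shown here.

```python
import functools
import types

# Hypothetical stand-in for the PR's `_RAW_CALLABLE_TYPES` (assumed contents).
RAW_CALLABLE_TYPES = (
    types.FunctionType,
    types.BuiltinFunctionType,
    types.MethodType,
    types.BuiltinMethodType,
    functools.partial,
    functools.partialmethod,
)


def is_raw_callable(value):
    """Return True if `value` is one of the raw callable types above."""
    return isinstance(value, RAW_CALLABLE_TYPES)


class Scaler:
    """A callable class, used to contrast instances with bound methods."""

    def __call__(self, x):
        return 2 * x

    def scale(self, x):
        return 2 * x


def plain_function(x):
    return x


checks = {
    "function": is_raw_callable(plain_function),
    "lambda": is_raw_callable(lambda x: x),
    "builtin": is_raw_callable(len),
    "partial": is_raw_callable(functools.partial(plain_function, 1)),
    "bound method": is_raw_callable(Scaler().scale),
    "class object": is_raw_callable(Scaler),
    "callable instance": is_raw_callable(Scaler()),
}
```

In this sketch only the last two entries evaluate to `False`: a class object is an instance of `type`, and a callable instance is an instance of its own class, so neither matches the tuple even though both are callable.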

Existing __lambda__ safe_mode behavior is preserved: the walk sees the serialized bytecode as a string and passes through, so the pre-existing "arbitrary code execution" error still fires downstream under safe_mode=True.

Contributor Agreement

Please check all boxes below before submitting your PR for review:

  • I am a human, and not a bot.
  • I will be responsible for responding to review comments in a timely manner.
  • I will work with the maintainers to push this PR forward until submission.

@gemini-code-assist gemini-code-assist Bot left a comment

Code Review

This pull request implements a security measure to block raw Python callables during Keras object deserialization when safe_mode is active. It introduces _reject_raw_callables, a recursive function that identifies and rejects unauthorized types such as functions, methods, and functools.partial objects within configurations. The review feedback highlights a potential performance bottleneck in the recursive traversal and suggests optimizing the function by skipping checks for PLAIN_TYPES and None values, which are common leaf nodes.

Comment on lines +418 to +433
def _reject_raw_callables(config, path="config"):
    if isinstance(config, _RAW_CALLABLE_TYPES):
        raise TypeError(
            f"Received a raw Python callable at {path!r} during "
            "deserialization. Configs must contain only serializable "
            "values; wrap custom callables with "
            "`@keras.saving.register_keras_serializable()` and pass the "
            "serialized dict, or pass `safe_mode=False` to allow raw "
            f"callables. Received: {config!r}"
        )
    if isinstance(config, dict):
        for k, v in config.items():
            _reject_raw_callables(v, f"{path}.{k}")
    elif isinstance(config, (list, tuple)):
        for i, v in enumerate(config):
            _reject_raw_callables(v, f"{path}[{i}]")

medium

The recursive walk in _reject_raw_callables can be a performance bottleneck for models with large configurations (e.g., layers containing large lists of constants or hyperparameters). Since _reject_raw_callables is called on every class instantiation during deserialization, and it is itself recursive, it may redundantly walk nested structures multiple times.

To improve efficiency, we should skip recursion for PLAIN_TYPES and None values, which are the most common leaf nodes in a Keras config. Adding an early exit at the top of the function and checking within the loops will significantly reduce function call overhead and string formatting for large collections.

def _reject_raw_callables(config, path="config"):
    if config is None or isinstance(config, PLAIN_TYPES):
        return
    if isinstance(config, _RAW_CALLABLE_TYPES):
        raise TypeError(
            f"Received a raw Python callable at {path!r} during "
            "deserialization. Configs must contain only serializable "
            "values; wrap custom callables with "
            "`@keras.saving.register_keras_serializable()` and pass the "
            "serialized dict, or pass `safe_mode=False` to allow raw "
            f"callables. Received: {config!r}"
        )
    if isinstance(config, dict):
        for k, v in config.items():
            if v is None or isinstance(v, PLAIN_TYPES):
                continue
            _reject_raw_callables(v, f"{path}.{k}")
    elif isinstance(config, (list, tuple)):
        for i, v in enumerate(config):
            if v is None or isinstance(v, PLAIN_TYPES):
                continue
            _reject_raw_callables(v, f"{path}[{i}]")
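
To show how the `path` argument surfaces in the error, here is a self-contained sketch of the same walk applied to a toy config. It assumes `PLAIN_TYPES` is `(str, int, float, bool)` and that the raw callable tuple holds the types named in the PR description; the configs, the `payload` function, and the shortened error message are hypothetical.

```python
import functools
import types

# Assumed values; the real constants live in the Keras serialization module.
PLAIN_TYPES = (str, int, float, bool)
RAW_CALLABLE_TYPES = (
    types.FunctionType,
    types.BuiltinFunctionType,
    types.MethodType,
    types.BuiltinMethodType,
    functools.partial,
    functools.partialmethod,
)


def reject_raw_callables(config, path="config"):
    """Recursively raise TypeError on the first raw callable found."""
    if config is None or isinstance(config, PLAIN_TYPES):
        return  # Common leaf values: nothing to check.
    if isinstance(config, RAW_CALLABLE_TYPES):
        raise TypeError(f"Received a raw Python callable at {path!r}")
    if isinstance(config, dict):
        for k, v in config.items():
            reject_raw_callables(v, f"{path}.{k}")
    elif isinstance(config, (list, tuple)):
        for i, v in enumerate(config):
            reject_raw_callables(v, f"{path}[{i}]")


def payload(x):
    """Hypothetical function smuggled into a config."""
    return x


clean = {"class_name": "Dense", "config": {"units": 4, "activation": "relu"}}
tainted = {"class_name": "Dense", "config": {"units": 4, "sizes": [3, payload]}}

reject_raw_callables(clean)  # Passes silently.

try:
    reject_raw_callables(tainted)
    error_message = None
except TypeError as exc:
    error_message = str(exc)
```

The dotted and indexed path in the message pinpoints the offending entry, which matters when the callable sits several levels deep.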

@Saumay Saumay Apr 21, 2026

I agree with this. The current placement walks the outer config once per class-path entry, and nested deserialize_keras_object calls on sub-objects re-walk their own subtrees, so the worst case is O(depth × nodes). For typical model sizes (10-100 layers, each config with 10-100 fields) this is a small overhead, but it compounds on deeply nested configs (e.g. a Transformer stack).

Two mitigation options if this becomes a concern:

  1. Track walked ids in a thread-local set, skip re-walks
  2. Walk only at the top-level deserialize_keras_object entry by using a depth flag

Happy to implement either if the maintainer wants the guard before merge.

@Saumay Saumay force-pushed the fix-deserialize-raw-callables branch 2 times, most recently from cac4449 to adcc78c Compare April 21, 2026 04:40

codecov-commenter commented Apr 21, 2026

Codecov Report

✅ All modified and coverable lines are covered by tests.
✅ Project coverage is 84.10%. Comparing base (8f5a00c) to head (4dd25e2).

Additional details and impacted files
@@            Coverage Diff             @@
##           master   #22736      +/-   ##
==========================================
- Coverage   84.56%   84.10%   -0.46%     
==========================================
  Files         462      462              
  Lines       66767    66780      +13     
  Branches    10897    10903       +6     
==========================================
- Hits        56459    56167     -292     
- Misses       7441     7757     +316     
+ Partials     2867     2856      -11     
Flag Coverage Δ
keras 83.92% <100.00%> (-0.45%) ⬇️
keras-cpu 83.92% <100.00%> (+<0.01%) ⬆️
keras-gpu ?
keras-jax 58.27% <100.00%> (-0.18%) ⬇️
keras-numpy 53.75% <100.00%> (+<0.01%) ⬆️
keras-openvino 59.55% <100.00%> (+<0.01%) ⬆️
keras-tensorflow 59.68% <100.00%> (-0.15%) ⬇️
keras-torch 58.43% <100.00%> (-0.16%) ⬇️
keras-tpu ?

@Saumay Saumay force-pushed the fix-deserialize-raw-callables branch 5 times, most recently from 1a00eaf to 5dd853d Compare April 21, 2026 07:55
@keerthanakadiri keerthanakadiri added the stat:awaiting keras-eng Awaiting response from Keras engineer label Apr 23, 2026
Previously `deserialize_keras_object` silently accepted raw Python
functions embedded in a config's inner values, storing them on the
reconstructed object and invoking them later (e.g. during `build()`).
Raw callables are never a valid serialization target (JSON cannot
represent them) and can lead to unexpected execution when configs
come from untrusted sources.

Reject raw callables (function, method, builtin, functools.partial)
recursively through `inner_config` before `cls.from_config` when
`safe_mode=True` (the default). Users who need to pass raw callables
can opt out via `safe_mode=False`, matching the existing opt-out for
lambda deserialization.

Class objects and class instances are unaffected, since they are
legitimate config values (layer instances, constraint instances, etc.).

Fixes keras-team#22705.
@Saumay Saumay force-pushed the fix-deserialize-raw-callables branch from 5dd853d to 4dd25e2 Compare April 26, 2026 03:25
@jeffcarp jeffcarp self-requested a review April 28, 2026 15:30
Labels

awaiting review size:M stat:awaiting keras-eng Awaiting response from Keras engineer

Development

Successfully merging this pull request may close these issues.

[Bug] deserialize_keras_object accepts callable objects in config without validation

4 participants