Skip to content

Commit 0f4a4fa

Browse files
authored
Merge pull request #1524 from rstudio/book-driven-updates
Book driven updates
2 parents bbcb58f + 52d30e9 commit 0f4a4fa

16 files changed

+194
-58
lines changed

NAMESPACE

Lines changed: 6 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -1,16 +1,19 @@
11
# Generated by roxygen2: do not edit by hand
22

33
S3method("!=",keras_shape)
4+
S3method("$",python.builtin.super)
45
S3method("$",python_builtin_super_getter)
56
S3method("$<-",keras.src.callbacks.callback.Callback)
67
S3method("+",keras.src.backend.common.keras_tensor.KerasTensor)
78
S3method("==",keras.src.backend.common.keras_tensor.KerasTensor)
89
S3method("==",keras_shape)
910
S3method("[",keras_shape)
11+
S3method("[[",python.builtin.super)
1012
S3method("[[",python_builtin_super_getter)
1113
S3method(Arg,keras.src.backend.Tensor)
1214
S3method(Arg,keras.src.backend.common.keras_tensor.KerasTensor)
1315
S3method(Summary,keras_shape)
16+
S3method(all,equal.numpy.ndarray)
1417
S3method(as.array,jax.Array)
1518
S3method(as.array,jaxlib._jax.ArrayImpl)
1619
S3method(as.array,jaxlib.xla_extension.ArrayImpl)
@@ -494,6 +497,7 @@ export(metric_sum)
494497
export(metric_top_k_categorical_accuracy)
495498
export(metric_true_negatives)
496499
export(metric_true_positives)
500+
export(named_list)
497501
export(new_callback_class)
498502
export(new_layer_class)
499503
export(new_learning_rate_schedule_class)
@@ -806,6 +810,7 @@ export(optimizer_sgd)
806810
export(pad_sequences)
807811
export(pop_layer)
808812
export(predict_on_batch)
813+
export(py_help)
809814
export(py_require)
810815
export(py_to_r)
811816
export(quantize_weights)
@@ -887,6 +892,7 @@ importFrom(reticulate,py_func)
887892
importFrom(reticulate,py_get_attr)
888893
importFrom(reticulate,py_get_item)
889894
importFrom(reticulate,py_has_attr)
895+
importFrom(reticulate,py_help)
890896
importFrom(reticulate,py_install)
891897
importFrom(reticulate,py_is_null_xptr)
892898
importFrom(reticulate,py_iterator)

NEWS.md

Lines changed: 60 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -1,60 +1,88 @@
11
# keras3 (development version)
22

3-
- Added S3 methods for JAX array: `str`, `as.array`, `as.double`, `as.integer`, `as.numeric`.
4-
5-
- Added `str` S3 method for Keras Variables.
6-
7-
- `layer_reshape()` can now accept `-1` as a sentinel for an automatically calculated axis size.
3+
- Expanded numeric operations with `op_layer_normalization()`, `op_cbrt()`,
4+
`op_corrcoef()`, `op_deg2rad()`, `op_heaviside()`, the new `op_sparse_sigmoid()`
5+
plus matching `activation_sparse_sigmoid()`, and an `attn_logits_soft_cap`
6+
argument for `op_dot_product_attention()`.
87

9-
- Updated dependencies declared by `use_backend("jax", gpu=TRUE)`
10-
for compatability with `keras-hub`.
8+
- Added signal window operations: `op_bartlett()`, `op_blackman()`,
9+
`op_hamming()`, `op_hanning()`, and `op_kaiser()`.
1110

12-
- Added training loop configuration helpers:
13-
`config_max_epochs()`, `config_set_max_epochs()`, `config_max_steps_per_epoch()`,
14-
and `config_set_max_steps_per_epoch()`. The caps can also be set via the
15-
`KERAS_MAX_EPOCHS` and `KERAS_MAX_STEPS_PER_EPOCH` environment variables.
16-
Added `config_is_nnx_enabled()` to check whether JAX NNX features are enabled.
11+
- Added `loss_categorical_generalized_cross_entropy()` for training with noisy
12+
labels.
1713

1814
- LoRA-enabled layers (`layer_dense()`, `layer_embedding()`, `layer_einsum_dense()`)
1915
gain a `lora_alpha` argument to scale the adaptation delta independently of the
2016
chosen rank.
2117

22-
- `keras_variable()` now accepts a `synchronization` argument for distributed
23-
strategies.
18+
- Added complex-valued helpers: S3 `Arg()` methods for tensors, `op_angle()`,
19+
and conversions `op_view_as_real()` / `op_view_as_complex()`.
2420

25-
- `Layer$add_weight()` gains an `overwrite_with_gradient` option and
26-
layers now provide a `symbolic_call()` method.
21+
- Added the Muon optimizer via `optimizer_muon()`.
22+
23+
- Added elastic deformation utilities for images: `layer_random_elastic_transform()`
24+
and the lower-level `op_image_elastic_transform()`.
2725

2826
- Transposed convolution utilities now follow the latest Keras API:
2927
`op_conv_transpose()` defaults `strides = 1` and the `layer_conv_*_transpose()`
3028
layers expose `output_padding` for precise shape control.
3129

32-
- `layer_torch_module_wrapper()` gains an `output_shape` argument to help Keras
33-
infer shapes when wrapping PyTorch modules.
30+
- `register_keras_serializable()` now returns a registered Python callable,
31+
making it easier to use with bare R functions.
3432

3533
- `save_model_weights()` adds a `max_shard_size` argument to split large weight
3634
files into manageable shards.
3735

38-
- Added elastic deformation utilities for images: `layer_random_elastic_transform()`
39-
and the lower-level `op_image_elastic_transform()`.
36+
- `keras_variable()` now accepts a `synchronization` argument for distributed
37+
strategies.
4038

41-
- Added `loss_categorical_generalized_cross_entropy()` for training with noisy
42-
labels.
39+
- `layer_layer_normalization()` removes the `rms_scaling` argument.
4340

44-
- Added the Muon optimizer via `optimizer_muon()`.
41+
- `layer_reshape()` can now accept `-1` as a sentinel for an automatically calculated axis size.
4542

46-
- Added complex-valued helpers: S3 `Arg()` methods for tensors, `op_angle()`,
47-
and conversions `op_view_as_real()` / `op_view_as_complex()`.
43+
- `layer_torch_module_wrapper()` gains an `output_shape` argument to help Keras
44+
infer shapes when wrapping PyTorch modules.
4845

49-
- Added signal window operations: `op_bartlett()`, `op_blackman()`,
50-
`op_hamming()`, `op_hanning()`, and `op_kaiser()`.
46+
- `Layer$add_weight()` gains an `overwrite_with_gradient` option and
47+
layers now provide a `symbolic_call()` method.
5148

52-
- Expanded numeric operations with `op_layer_normalization()`, `op_cbrt()`,
53-
`op_corrcoef()`, `op_deg2rad()`, `op_heaviside()`, the new `op_sparse_sigmoid()`
54-
plus matching `activation_sparse_sigmoid()`, and an `attn_logits_soft_cap`
55-
argument for `op_dot_product_attention()`.
49+
- Added `str()` S3 method for Keras Variables.
5650

57-
- `layer_layer_normalization()` removes the `rms_scaling` argument.
51+
- Added S3 methods for JAX array:
52+
`str()`, `as.array()`, `as.double()`, `as.integer()`, `as.numeric()`.
53+
54+
- Added base-array compatibility methods for backend tensors: `t()`,
55+
`aperm()`, and `all.equal()`.
56+
57+
- Added `pillar::type_sum()` for JAX variables and `JaxVariable`;
58+
extended `str()` coverage to the new JAX variable class.
59+
60+
- `config_max_epochs()`, `config_set_max_epochs()`, `config_max_steps_per_epoch()`,
61+
and `config_set_max_steps_per_epoch()`. The caps can also be set via the
62+
`KERAS_MAX_EPOCHS` and `KERAS_MAX_STEPS_PER_EPOCH` environment variables.
63+
Added `config_is_nnx_enabled()` to check whether JAX NNX features are enabled.
64+
65+
- Built-in dataset loaders now accept `convert = FALSE` to return NumPy arrays
66+
instead of R arrays.
67+
68+
- Updated `plot(history, theme_bw = TRUE)` for `ggplot2` 3.4.0
69+
compatibility.
70+
71+
- `plot(model)` DPI is now globally configurable via
72+
`options(keras.plot.model.dpi = )`, (defaults to `200`).
73+
74+
- Reexported reticulate functions: `py_help()`, `py_to_r()`, `r_to_py()`,
75+
`py_require()`, and `import()`.
76+
77+
- Support `super()$initialize()` in subclassed Keras classes; improved
78+
`super()` behavior in subclasses.
79+
80+
- Updated dependencies declared by `use_backend("jax", gpu=TRUE)`
81+
for compatability with `keras-hub`.
82+
83+
- Exported `named_list()` utility.
84+
85+
- Fixed an issue when switching backends twice in a row.
5886

5987
# keras3 1.4.0
6088

R/install.R

Lines changed: 1 addition & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -388,7 +388,7 @@ uv_unset_override_never_tensorflow <- function() {
388388
if (is.na(override)) return()
389389
cpu_override <- pkg_file("never-tensorflow-override.txt")
390390
if (override == cpu_override) {
391-
Sys.unsetenv(override)
391+
Sys.unsetenv("UV_OVERRIDE")
392392
} else {
393393
new <- gsub(cpu_override, "", override, fixed = TRUE)
394394
new <- gsub(" +", " ", new)

R/model-persistence.R

Lines changed: 3 additions & 3 deletions
Original file line numberDiff line numberDiff line change
@@ -585,8 +585,8 @@ function (object, filepath, call_endpoint = "serve", call_training_endpoint = NU
585585
#' @param object
586586
#' A keras object.
587587
#'
588-
#' @returns `object` is returned invisibly, for convenient piping. This is
589-
#' primarily called for side effects.
588+
#' @returns The registered `object` (and converted) is returned. This returned object is what you
589+
#' should must use when building and serializing the model.
590590
#' @export
591591
#' @family saving and loading functions
592592
#' @family serialization utilities
@@ -605,7 +605,7 @@ function (object, name = NULL, package = NULL)
605605
c("", "base", "R_GlobalEnv"), "Custom")
606606

607607
keras$saving$register_keras_serializable(package, name)(py_object)
608-
invisible(object)
608+
py_object
609609
}
610610

611611

R/package.R

Lines changed: 16 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -186,7 +186,12 @@ keras <- NULL
186186
keras <- import("keras")
187187
convert_to_tensor <- import("keras.ops", convert = FALSE)$convert_to_tensor
188188
with(keras$device("cpu:0"), {
189-
backend_tensor_class <- class(convert_to_tensor(array(1L)))[1L]
189+
all_backend_tensor_s3_classes <- class(convert_to_tensor(array(1L)))
190+
backend_tensor_class <- all_backend_tensor_s3_classes[1L]
191+
if ("jax.Array" %in% all_backend_tensor_s3_classes)
192+
backend_tensor_class <- "jax.Array"
193+
# message("setting methods on backend_tensor_class: ", backend_tensor_class,
194+
# "\nother options: ", paste0(all_backend_tensor_s3_classes, collapse = " "))
190195
})
191196
symbolic_tensor_class <- nameOfClass__python.builtin.type(keras$KerasTensor)
192197

@@ -207,6 +212,9 @@ keras <- NULL
207212
registerS3method("as.array", backend_tensor_class, op_convert_to_array, baseenv())
208213
registerS3method("^", backend_tensor_class, `^__keras.backend.tensor`, baseenv())
209214
registerS3method("%*%", backend_tensor_class, op_matmul, baseenv())
215+
registerS3method("t", backend_tensor_class, op_transpose, baseenv())
216+
registerS3method("aperm", backend_tensor_class, op_transpose, baseenv())
217+
registerS3method("all.equal", backend_tensor_class, all.equal.numpy.ndarray, baseenv())
210218

211219
if(keras$config$backend() == "jax") {
212220
for(py_type in import("jax")$Array$`__subclasses__`()) {
@@ -271,6 +279,13 @@ keras <- NULL
271279

272280
}
273281

282+
## should this live in reticulate?? probably...
283+
#' @export
284+
all.equal.numpy.ndarray <- function(target, current, ...) {
285+
# or use numpy.allequal?
286+
all.equal(as.array(target), as.array(current), ...)
287+
}
288+
274289

275290
at.keras_backend_tensor <- function(object, name) {
276291
out <- rlang::env_clone(object)

R/py-classes.R

Lines changed: 19 additions & 7 deletions
Original file line numberDiff line numberDiff line change
@@ -119,12 +119,16 @@ function(classname,
119119
type = `__class__`,
120120
object_or_type = base::get("self", envir = base::parent.frame()))
121121
{
122-
convert <- base::get("convert", envir = base::as.environment(object_or_type))
123-
py_builtins <- reticulate::import_builtins(convert)
124-
reticulate::py_call(py_builtins$super, type, object_or_type)
125-
}
122+
convert <- base::get("convert", object_or_type)
123+
py_super <- reticulate::py_eval(
124+
"__import__('builtins').super",
125+
convert = convert
126+
)
127+
py_super(type, object_or_type)
128+
}
126129
class(super) <- "python_builtin_super_getter"
127-
}))
130+
})
131+
)
128132

129133

130134
py_class
@@ -137,15 +141,23 @@ function(classname,
137141
#' @export
138142
`$.python_builtin_super_getter` <- function(x, name) {
139143
super <- do.call(x, list(), envir = parent.frame()) # call super()
144+
`[[.python.builtin.super`(super, name)
145+
}
146+
147+
#' @export
148+
`[[.python_builtin_super_getter` <- `$.python_builtin_super_getter`
149+
150+
#' @export
151+
`$.python.builtin.super` <- function(x, name) {
140152
name <- switch(name, initialize = "__init__", finalize = "__del__", name)
141-
out <- py_get_attr(super, name)
153+
out <- py_get_attr(x, name)
142154
convert <- get0("convert", as.environment(out), inherits = FALSE,
143155
ifnotfound = TRUE)
144156
if (convert) py_to_r(out) else out
145157
}
146158

147159
#' @export
148-
`[[.python_builtin_super_getter` <- `$.python_builtin_super_getter`
160+
`[[.python.builtin.super` <- `$.python.builtin.super`
149161

150162
# No .DollarNames.python_builtin_super_getter because the python.builtin.super
151163
# object doesn't have populated attributes itself, only a dynamic `__getattr__`

R/r-utils.R

Lines changed: 23 additions & 4 deletions
Original file line numberDiff line numberDiff line change
@@ -40,15 +40,34 @@ drop_nulls <- function(x, i = NULL) {
4040
x[!drop]
4141
}
4242

43+
#' Create a named list from arguments
44+
#'
45+
#' Constructs a list from the provided arguments where all elements are named.
46+
#' This wraps [rlang::dots_list()] but changes two defaults:
47+
#' - `.named` is set to `TRUE`
48+
#' - `.homonyms` is set to `"error"`
49+
#'
50+
#' Other parameters retain their defaults from [rlang::dots_list()]:
51+
#' - `.ignore_empty = "trailing"`
52+
#' - `.preserve_empty = FALSE`
53+
#' - `.check_assign = FALSE`
54+
#'
55+
#' @inheritParams rlang::dots_list
56+
#'
57+
#' @inheritParams dots_list
58+
#'
59+
#' @return A named list.
60+
#'
61+
#' @seealso [rlang::dots_list()]
62+
#'
63+
#' @export
4364
#' @importFrom rlang dots_list
44-
# identical to rlang::list2(), except .named = TRUE
4565
named_list <- function(...)
4666
dots_list(...,
47-
.named = TRUE,
48-
# not the default
67+
.named = TRUE, # not default
4968
.ignore_empty = "trailing",
5069
.preserve_empty = FALSE,
51-
.homonyms = "error",
70+
.homonyms = "error", # not default
5271
.check_assign = FALSE)
5372

5473
`append1<-` <- function(x, value) {

R/reexports.R

Lines changed: 3 additions & 0 deletions
Original file line numberDiff line numberDiff line change
@@ -77,6 +77,9 @@ reticulate::py_to_r
7777
#' @export
7878
reticulate::r_to_py
7979

80+
#' @export
81+
reticulate::py_help
82+
8083
#' @importFrom tensorflow tensorboard
8184
#' @export
8285
tensorflow::tensorboard

R/utils.R

Lines changed: 3 additions & 1 deletion
Original file line numberDiff line numberDiff line change
@@ -369,6 +369,8 @@ to_categorical <-
369369
function (x, num_classes = NULL)
370370
{
371371
if (inherits(x, "factor")) {
372+
# if (length(DIM(x)) == 1)
373+
# return(diag(nrow = num_classes %||% length(levels(x)))[as.integer(x), ])
372374
x <- array(as.integer(x) - 1L, dim = dim(x) %||% length(x))
373375
if (is.null(num_classes))
374376
num_classes <- length(levels(x))
@@ -631,7 +633,7 @@ function(x,
631633
...,
632634
rankdir = "TB",
633635
expand_nested = FALSE,
634-
dpi = 200,
636+
dpi = getOption("keras.plot.model.dpi", 200L),
635637
layer_range = NULL,
636638
show_layer_activations = FALSE,
637639
show_trainable = NA,

man/deserialize_keras_object.Rd

Lines changed: 11 additions & 1 deletion
Some generated files are not rendered by default. Learn more about customizing how changed files appear on GitHub.

0 commit comments

Comments
 (0)