Skip to content

Commit 012b901

Browse files
authored
rust: define blob key encoding (#4545)
Summary: The data provider RPC protocol requires servers to emit “blob keys”, opaque, URL-safe tokens that uniquely identify blobs. This patch defines an in-memory `BlobKey` struct with encoding and decoding. The wire format is URL-safe base64 (with no padding) over JSON. The seemingly unrelated changes are needed because loading `serde_json` adds trait implementations that make some type inference inadmissible: <rust-lang/rust#80964> Test Plan: Unit tests included. wchargin-branch: rust-blob-keys
1 parent 0fe35b5 commit 012b901

File tree

5 files changed

+152
-2
lines changed

5 files changed

+152
-2
lines changed

tensorboard/data/server/BUILD

+4
Original file line numberDiff line numberDiff line change
@@ -27,6 +27,7 @@ rust_library(
2727
name = "rustboard_core",
2828
srcs = [
2929
"lib.rs",
30+
"blob_key.rs",
3031
"cli.rs",
3132
"commit.rs",
3233
"data_compat.rs",
@@ -44,6 +45,7 @@ rust_library(
4445
] + _checked_in_proto_files,
4546
edition = "2018",
4647
deps = [
48+
"//third_party/rust:base64",
4749
"//third_party/rust:byteorder",
4850
"//third_party/rust:clap",
4951
"//third_party/rust:crc",
@@ -53,6 +55,8 @@ rust_library(
5355
"//third_party/rust:prost",
5456
"//third_party/rust:rand",
5557
"//third_party/rust:rand_chacha",
58+
"//third_party/rust:serde",
59+
"//third_party/rust:serde_json",
5660
"//third_party/rust:thiserror",
5761
"//third_party/rust:tokio",
5862
"//third_party/rust:tonic",

tensorboard/data/server/blob_key.rs

+145
Original file line numberDiff line numberDiff line change
@@ -0,0 +1,145 @@
1+
/* Copyright 2021 The TensorFlow Authors. All Rights Reserved.
2+
3+
Licensed under the Apache License, Version 2.0 (the "License");
4+
you may not use this file except in compliance with the License.
5+
You may obtain a copy of the License at
6+
7+
http://www.apache.org/licenses/LICENSE-2.0
8+
9+
Unless required by applicable law or agreed to in writing, software
10+
distributed under the License is distributed on an "AS IS" BASIS,
11+
WITHOUT WARRANTIES OR CONDITIONS OF ANY KIND, either express or implied.
12+
See the License for the specific language governing permissions and
13+
limitations under the License.
14+
==============================================================================*/
15+
16+
//! Opaque, URL-safe blob keys.
17+
18+
use serde::{Deserialize, Serialize};
19+
use std::borrow::Cow;
20+
use std::convert::TryFrom;
21+
use std::fmt::Display;
22+
use std::str::FromStr;
23+
24+
use crate::types::Step;
25+
26+
const BASE_64_CONFIG: base64::Config = base64::URL_SAFE_NO_PAD;
27+
28+
/// Unique identifier for a blob.
29+
///
30+
/// Blob keys are returned by the `ReadBlobSequences` RPC, and can be dereferenced via the
31+
/// `ReadBlob` RPC.
32+
///
33+
/// Blob keys implement [`Display`] and [`FromStr`], which should be used for encoding and
34+
/// decoding, respectively. The `Display` format of a blob key is URL-safe.
35+
#[derive(Debug, Clone, PartialEq, Eq)]
36+
pub struct BlobKey<'a> {
37+
pub experiment_id: Cow<'a, str>,
38+
pub run: Cow<'a, str>,
39+
pub tag: Cow<'a, str>,
40+
pub step: Step,
41+
pub index: usize,
42+
}
43+
44+
/// Helper struct to encode `BlobKey`s as tuples (rather than objects with named keys) and to use
45+
/// portable integers over the wire.
46+
#[derive(Debug, Serialize, Deserialize)]
47+
struct WireBlobKey<'a>(&'a str, &'a str, &'a str, i64, u64);
48+
49+
/// An error returned when parsing a `BlobKey`.
50+
#[derive(Debug, thiserror::Error)]
51+
pub enum ParseBlobKeyError {
52+
#[error("invalid base-64: {}", .0)]
53+
BadBase64(base64::DecodeError),
54+
#[error("invalid JSON: {}", .0)]
55+
BadJson(serde_json::Error),
56+
#[error("index does not fit in memory on this system: {} > {}", .0, usize::MAX)]
57+
BadIndex(u64),
58+
}
59+
60+
impl<'a> FromStr for BlobKey<'a> {
61+
type Err = ParseBlobKeyError;
62+
63+
fn from_str(s: &str) -> Result<Self, Self::Err> {
64+
let buf = base64::decode_config(s, BASE_64_CONFIG).map_err(ParseBlobKeyError::BadBase64)?;
65+
let WireBlobKey(experiment_id, run, tag, step, index) =
66+
serde_json::from_slice(&buf).map_err(ParseBlobKeyError::BadJson)?;
67+
let index = usize::try_from(index).map_err(|_| ParseBlobKeyError::BadIndex(index))?;
68+
Ok(BlobKey {
69+
experiment_id: Cow::Owned(experiment_id.into()),
70+
run: Cow::Owned(run.into()),
71+
tag: Cow::Owned(tag.into()),
72+
step: Step(step),
73+
index,
74+
})
75+
}
76+
}
77+
78+
impl<'a> Display for BlobKey<'a> {
79+
fn fmt(&self, f: &mut std::fmt::Formatter<'_>) -> std::fmt::Result {
80+
use base64::display::Base64Display;
81+
let wire = WireBlobKey(
82+
&self.experiment_id,
83+
&self.run,
84+
&self.tag,
85+
self.step.0,
86+
self.index as u64,
87+
);
88+
let json =
89+
serde_json::to_string(&wire).expect("wire blob keys should always be serializable");
90+
Base64Display::with_config(json.as_bytes(), BASE_64_CONFIG).fmt(f)
91+
}
92+
}
93+
94+
#[cfg(test)]
95+
mod tests {
96+
use super::*;
97+
98+
#[test]
99+
fn test_roundtrip() {
100+
let key = BlobKey {
101+
experiment_id: Cow::Borrowed("123"),
102+
run: Cow::Owned("mnist".to_string()),
103+
tag: Cow::Borrowed("input_image"),
104+
step: Step(777),
105+
index: 123,
106+
};
107+
assert_eq!(key.to_string().parse::<BlobKey>().unwrap(), key);
108+
}
109+
110+
#[test]
111+
fn test_no_padding() {
112+
for eid_length in 0..10 {
113+
let key = BlobKey {
114+
experiment_id: Cow::Owned("x".repeat(eid_length)),
115+
run: Cow::Borrowed("run"),
116+
tag: Cow::Borrowed("tag"),
117+
step: Step(0),
118+
index: 0,
119+
};
120+
let encoded = key.to_string();
121+
assert!(
122+
!encoded.ends_with('='),
123+
"encoded form should not end with '=': {:?} => {:?}",
124+
key,
125+
encoded,
126+
);
127+
}
128+
}
129+
130+
#[test]
131+
fn test_bad_base64() {
132+
match "???".parse::<BlobKey>().unwrap_err() {
133+
ParseBlobKeyError::BadBase64(_) => (),
134+
other => panic!("expected BadBase64(_), got {:?}", other),
135+
};
136+
}
137+
138+
#[test]
139+
fn test_bad_json() {
140+
match "AAAAAA".parse::<BlobKey>().unwrap_err() {
141+
ParseBlobKeyError::BadJson(_) => (),
142+
other => panic!("expected BadJson(_), got {:?}", other),
143+
};
144+
}
145+
}

tensorboard/data/server/data_compat.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -453,7 +453,7 @@ mod tests {
453453
fn test() {
454454
let md = GraphDefValue::initial_metadata();
455455
assert_eq!(&md.plugin_data.unwrap().plugin_name, GRAPHS_PLUGIN_NAME);
456-
assert_eq!(md.data_class, pb::DataClass::BlobSequence.into());
456+
assert_eq!(md.data_class, i32::from(pb::DataClass::BlobSequence));
457457
}
458458
}
459459

tensorboard/data/server/lib.rs

+1
Original file line numberDiff line numberDiff line change
@@ -17,6 +17,7 @@ limitations under the License.
1717
1818
#![allow(clippy::needless_update)] // https://github.com/rust-lang/rust-clippy/issues/6323
1919

20+
pub mod blob_key;
2021
pub mod cli;
2122
pub mod commit;
2223
pub mod data_compat;

tensorboard/data/server/server.rs

+1-1
Original file line numberDiff line numberDiff line change
@@ -626,6 +626,6 @@ mod tests {
626626
let map = run_tag_map!(res.runs);
627627
let train_run = &map[&Run("train".to_string())];
628628
let xent_data = &train_run[&Tag("xent".to_string())].data.as_ref().unwrap();
629-
assert_eq!(xent_data.value, Vec::new());
629+
assert_eq!(xent_data.value, Vec::<f32>::new());
630630
}
631631
}

0 commit comments

Comments
 (0)