Skip to content

Commit c2e6e56

Browse files
committed
feat: body add one/all/ref validator
1 parent cf1b38c commit c2e6e56

1 file changed

Lines changed: 85 additions & 32 deletions

File tree

src/request/axum.rs

Lines changed: 85 additions & 32 deletions
Original file line numberDiff line numberDiff line change
@@ -14,6 +14,7 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17+
use crate::model::parse;
1718
use crate::model::parse::{Format, In, Method, OpenAPI};
1819
use crate::request::validator::ValidateRequest;
1920
use anyhow::{Context, Result};
@@ -153,51 +154,50 @@ impl ValidateRequest for RequestData {
153154
}
154155

155156
fn body(&self, open_api: &OpenAPI) -> Result<()> {
157+
let body = self
158+
.body
159+
.as_ref()
160+
.ok_or_else(|| anyhow::anyhow!("Missing body"))?;
156161
let path = open_api
157162
.paths
158163
.get(self.path.as_str())
159164
.context("Path not found")?;
160-
161165
let path_base = path
162166
.get(&Method::Post)
163167
.context("Post method not defined for this path")?;
164168

165169
if let Some(request) = &path_base.request {
166-
let map = &request.content;
167-
168-
if let Some(body) = &self.body {
169-
let request_fields: HashMap<String, Value> = serde_json::from_slice(body)?;
170-
171-
for (key, value) in map {
172-
if let Some(field) = request_fields.get(key) {
173-
match value.schema.format {
174-
Format::UUID => {
175-
if let Some(v) = field.as_str() {
176-
uuid::Uuid::parse_str(v).map_err(|_| {
177-
anyhow::anyhow!(
178-
"Invalid UUID format for query parameter '{}': '{:?}'",
179-
key,
180-
value
181-
)
182-
})?;
183-
} else {
184-
return Err(anyhow::anyhow!(
185-
"Invalid UUID format for query parameter '{}'",
186-
key
187-
));
188-
}
189-
}
190-
_ => {
191-
return Err(anyhow::anyhow!(
192-
"Unsupported format '{:?}' for query parameter '{}'",
193-
value.schema.format,
194-
key
195-
));
196-
}
170+
let request_fields: HashMap<String, Value> = serde_json::from_slice(body)?;
171+
let refs = collect_schema_refs(&request.content);
172+
173+
for (key, media_type) in &request.content {
174+
if let Some(field) = request_fields.get(key) {
175+
validate_field_format(key, field, media_type.schema.format.clone())?;
176+
}
177+
}
178+
179+
let mut requireds = HashSet::new();
180+
181+
if let Some(components) = &open_api.components {
182+
for schema_ref in &refs {
183+
if let Some(schema) = components.schemas.get(*schema_ref) {
184+
requireds.extend(schema.required.iter().cloned());
185+
186+
if let Some(properties) = &schema.properties {
187+
validate_schema_properties(&request_fields, properties)?;
197188
}
198189
}
199190
}
200191
}
192+
193+
for key in &requireds {
194+
if !request_fields.contains_key(key) {
195+
return Err(anyhow::anyhow!(
196+
"Missing required query parameter: '{}'",
197+
key
198+
));
199+
}
200+
}
201201
}
202202

203203
Ok(())
@@ -225,3 +225,56 @@ fn validate_format(format: &Format, value: &str, key: &str) -> Result<()> {
225225
}
226226
Ok(())
227227
}
228+
229+
fn collect_schema_refs(content: &HashMap<String, parse::BaseContent>) -> Vec<&str> {
230+
let mut refs = Vec::new();
231+
for media_type in content.values() {
232+
if let Some(r) = &media_type.schema._ref {
233+
refs.push(r.as_str());
234+
}
235+
if let Some(one_of) = &media_type.schema.one_of {
236+
refs.extend(one_of.iter().filter_map(|s| s._ref.as_deref()));
237+
}
238+
if let Some(all_of) = &media_type.schema.all_of {
239+
refs.extend(all_of.iter().filter_map(|s| s._ref.as_deref()));
240+
}
241+
}
242+
refs
243+
}
244+
245+
fn validate_field_format(key: &str, value: &Value, format: Format) -> Result<()> {
246+
match format {
247+
Format::UUID => {
248+
let str_val = value.as_str().ok_or_else(|| {
249+
anyhow::anyhow!("Invalid UUID format for query parameter '{}'", key)
250+
})?;
251+
uuid::Uuid::parse_str(str_val).map_err(|_| {
252+
anyhow::anyhow!(
253+
"Invalid UUID format for query parameter '{}': '{}'",
254+
key,
255+
str_val
256+
)
257+
})?;
258+
}
259+
_ => {
260+
return Err(anyhow::anyhow!(
261+
"Unsupported format '{:?}' for query parameter '{}'",
262+
format,
263+
key
264+
));
265+
}
266+
}
267+
Ok(())
268+
}
269+
270+
fn validate_schema_properties(
271+
request_fields: &HashMap<String, Value>,
272+
properties: &HashMap<String, parse::Properties>,
273+
) -> Result<()> {
274+
for (key, prop) in properties {
275+
if let Some(value) = request_fields.get(key) {
276+
validate_field_format(key, value, prop.format.clone())?;
277+
}
278+
}
279+
Ok(())
280+
}

0 commit comments

Comments
 (0)