Skip to content

Commit 2c15feb

Browse files
committed
feat: Multiplex the validator
1 parent 1b18784 commit 2c15feb

2 files changed

Lines changed: 293 additions & 279 deletions

File tree

src/request/axum.rs

Lines changed: 15 additions & 277 deletions
Original file line numberDiff line numberDiff line change
@@ -14,17 +14,14 @@
1414
* See the License for the specific language governing permissions and
1515
* limitations under the License.
1616
*/
17-
use crate::model::parse;
18-
use crate::model::parse::{Format, In, Method, OpenAPI};
19-
use crate::request::validator::ValidateRequest;
20-
use anyhow::{Context, Result};
17+
use crate::model::parse::OpenAPI;
18+
use crate::request;
19+
use crate::request::validator::{common_method, ValidateRequest};
20+
use anyhow::Result;
2121
use axum::body::{Body, Bytes};
2222
use axum::http::Request;
23-
use chrono::{DateTime, NaiveDate, NaiveTime};
2423
use serde_json::Value;
25-
use std::collections::{HashMap, HashSet};
26-
use std::net::{Ipv4Addr, Ipv6Addr};
27-
use std::str::FromStr;
24+
use std::collections::HashMap;
2825

2926
#[allow(dead_code)]
3027
pub struct RequestData {
@@ -39,126 +36,22 @@ impl ValidateRequest for RequestData {
3936
}
4037

4138
fn method(&self, open_api: &OpenAPI) -> Result<()> {
42-
let path = open_api
43-
.paths
44-
.get(self.path.as_str())
45-
.context("Path not found")?;
46-
47-
let method =
48-
Method::from_str(self.inner.method().as_str()).map_err(|e| anyhow::anyhow!(e))?;
49-
50-
if path.get(&method).is_none() {
51-
return Err(anyhow::anyhow!("Path is empty"));
52-
}
53-
Ok(())
39+
common_method(self.path.as_str(), self.inner.method().as_str(), open_api)
5440
}
5541

5642
fn query(&self, open_api: &OpenAPI) -> Result<()> {
57-
let path = open_api
58-
.paths
59-
.get(self.path.as_str())
60-
.context("Path not found")?;
61-
62-
if let Some(path_base) = path.get(&Method::Get) {
63-
let query_str = self.inner.uri().query().unwrap_or_default();
64-
let query_pairs: HashMap<_, _> = url::form_urlencoded::parse(query_str.as_bytes())
65-
.into_owned()
66-
.collect();
67-
68-
let mut requireds: HashSet<String> = HashSet::new();
43+
let query_str = self.inner.uri().query().unwrap_or_default();
6944

70-
if let Some(parameters) = &path_base.parameters {
71-
for parameter in parameters {
72-
if parameter._in != In::Query {
73-
continue;
74-
}
75-
76-
if let Some(value) = query_pairs.get(&parameter.name) {
77-
validate_field_format(
78-
&parameter.name,
79-
&Value::from(value.as_str()),
80-
parameter.schema.format.clone(),
81-
)?;
82-
}
83-
84-
let mut refs = Vec::new();
85-
if let Some(r) = &parameter.schema._ref {
86-
refs.push(r.as_str());
87-
}
88-
if let Some(one_of) = &parameter.schema.one_of {
89-
for s in one_of {
90-
if let Some(r) = &s._ref {
91-
refs.push(r.as_str());
92-
}
93-
}
94-
}
95-
if let Some(all_of) = &parameter.schema.all_of {
96-
for s in all_of {
97-
if let Some(r) = &s._ref {
98-
refs.push(r.as_str());
99-
}
100-
}
101-
}
102-
103-
for schema_ref in refs {
104-
if let Some(components) = &open_api.components {
105-
if let Some(schema) = components.schemas.get(schema_ref) {
106-
if !schema.required.is_empty() {
107-
requireds.extend(schema.required.clone());
108-
}
109-
if let Some(properties) = &schema.properties {
110-
for (key, prop) in properties {
111-
if let Some(value) = query_pairs.get(key) {
112-
validate_field_format(
113-
key,
114-
&Value::from(value.as_str()),
115-
prop.format.clone(),
116-
)?;
117-
}
118-
}
119-
}
120-
}
121-
}
122-
}
123-
}
124-
}
125-
126-
for key in &requireds {
127-
if !query_pairs.contains_key(key) {
128-
return Err(anyhow::anyhow!(
129-
"Missing required query parameter: '{}'",
130-
key
131-
));
132-
}
133-
}
134-
}
45+
let query_pairs: HashMap<_, _> = url::form_urlencoded::parse(query_str.as_bytes())
46+
.into_owned()
47+
.collect();
13548

136-
Ok(())
49+
request::validator::common_query(self.path.as_str(), query_pairs, open_api)
13750
}
13851

13952
fn path(&self, open_api: &OpenAPI) -> Result<()> {
140-
let path = open_api
141-
.paths
142-
.get(self.path.as_str())
143-
.context("Path not found")?;
144-
145-
if let Some(path_base) = path.get(&Method::Get) {
146-
let uri = self.inner.uri();
147-
148-
if let Some(parameters) = &path_base.parameters {
149-
if let Some(last_segment) = uri.path().rsplit('/').find(|s| !s.is_empty()) {
150-
for parameter in parameters {
151-
if parameter._in != In::Path {
152-
continue;
153-
}
154-
validate_field_format(
155-
&parameter.name,
156-
&Value::from(last_segment),
157-
parameter.schema.format.clone(),
158-
)?;
159-
}
160-
}
161-
}
53+
if let Some(last_segment) = self.inner.uri().path().rsplit('/').find(|s| !s.is_empty()) {
54+
request::validator::common_path(self.path.as_str(), last_segment, open_api)?
16255
}
16356

16457
Ok(())
@@ -172,162 +65,7 @@ impl ValidateRequest for RequestData {
17265
.body
17366
.as_ref()
17467
.ok_or_else(|| anyhow::anyhow!("Missing body"))?;
175-
let path = open_api
176-
.paths
177-
.get(self.path.as_str())
178-
.context("Path not found")?;
179-
let path_base = path
180-
.get(&Method::Post)
181-
.context("Post method not defined for this path")?;
182-
183-
if let Some(request) = &path_base.request {
184-
let request_fields: HashMap<String, Value> = serde_json::from_slice(body)?;
185-
let refs = collect_schema_refs(&request.content);
186-
187-
for (key, media_type) in &request.content {
188-
if let Some(field) = request_fields.get(key) {
189-
validate_field_format(key, field, media_type.schema.format.clone())?;
190-
}
191-
}
192-
193-
let mut requireds = HashSet::new();
194-
195-
if let Some(components) = &open_api.components {
196-
for schema_ref in refs {
197-
if let Some(last_slash_pos) = schema_ref.rfind('/') {
198-
let filename = &schema_ref[last_slash_pos + 1..];
199-
if let Some(schema) = components.schemas.get(filename) {
200-
requireds.extend(schema.required.iter().cloned());
201-
if let Some(properties) = &schema.properties {
202-
validate_schema_properties(&request_fields, properties)?;
203-
}
204-
}
205-
} else {
206-
return Err(anyhow::anyhow!(
207-
"Invalid schema reference: '{}'",
208-
schema_ref
209-
));
210-
}
211-
}
212-
}
213-
214-
for key in &requireds {
215-
if !request_fields.contains_key(key) {
216-
return Err(anyhow::anyhow!(
217-
"Missing required query parameter: '{}'",
218-
key
219-
));
220-
}
221-
}
222-
}
223-
224-
Ok(())
225-
}
226-
}
227-
228-
fn collect_schema_refs(content: &HashMap<String, parse::BaseContent>) -> Vec<&str> {
229-
let mut refs = Vec::new();
230-
for media_type in content.values() {
231-
if let Some(r) = &media_type.schema._ref {
232-
refs.push(r.as_str());
233-
}
234-
if let Some(one_of) = &media_type.schema.one_of {
235-
refs.extend(one_of.iter().filter_map(|s| s._ref.as_deref()));
236-
}
237-
if let Some(all_of) = &media_type.schema.all_of {
238-
refs.extend(all_of.iter().filter_map(|s| s._ref.as_deref()));
239-
}
240-
}
241-
refs
242-
}
243-
244-
fn validate_field_format(key: &str, value: &Value, format: Format) -> Result<()> {
245-
let str_val = value
246-
.as_str()
247-
.ok_or_else(|| anyhow::anyhow!("this value must string '{}'", key))?;
248-
match format {
249-
Format::Email => {
250-
if !validator::validate_email(str_val) {
251-
return Err(anyhow::anyhow!(
252-
"Invalid email format for query parameter '{}': '{}'",
253-
key,
254-
str_val
255-
));
256-
}
257-
}
258-
Format::Time => {
259-
NaiveTime::parse_from_str(str_val, "%H:%M:%S").map_err(|_| {
260-
anyhow::anyhow!(
261-
"Invalid time format for query parameter '{}': '{}'",
262-
key,
263-
str_val
264-
)
265-
})?;
266-
}
267-
Format::Date => {
268-
NaiveDate::parse_from_str(str_val, "%Y-%m-%d").map_err(|_| {
269-
anyhow::anyhow!(
270-
"Invalid RFC3339 full-date for query parameter '{}': '{}'",
271-
key,
272-
str_val
273-
)
274-
})?;
275-
}
276-
Format::DateTime => {
277-
DateTime::parse_from_rfc3339(str_val).map_err(|_| {
278-
anyhow::anyhow!(
279-
"Invalid datetime format for query parameter '{}': '{}'",
280-
key,
281-
str_val
282-
)
283-
})?;
284-
}
285-
Format::UUID => {
286-
uuid::Uuid::parse_str(str_val).map_err(|_| {
287-
anyhow::anyhow!(
288-
"Invalid UUID format for query parameter '{}': '{}'",
289-
key,
290-
str_val
291-
)
292-
})?;
293-
}
294-
Format::IPV4 => {
295-
str_val.parse::<Ipv4Addr>().map_err(|_| {
296-
anyhow::anyhow!(
297-
"Invalid IPv4 format for query parameter '{}': '{}'",
298-
key,
299-
str_val
300-
)
301-
})?;
302-
}
303-
Format::IPV6 => {
304-
str_val.parse::<Ipv6Addr>().map_err(|_| {
305-
anyhow::anyhow!(
306-
"Invalid IPv6 format for query parameter '{}': '{}'",
307-
key,
308-
str_val
309-
)
310-
})?;
311-
}
312-
_ => {
313-
return Err(anyhow::anyhow!(
314-
"Unsupported format '{:?}' for query parameter '{}'",
315-
format,
316-
key
317-
));
318-
}
319-
}
320-
Ok(())
321-
}
322-
323-
fn validate_schema_properties(
324-
request_fields: &HashMap<String, Value>,
325-
properties: &HashMap<String, parse::Properties>,
326-
) -> Result<()> {
327-
for (key, prop) in properties {
328-
if let Some(value) = request_fields.get(key) {
329-
validate_field_format(key, value, prop.format.clone())?;
330-
}
68+
let request_fields: HashMap<String, Value> = serde_json::from_slice(body)?;
69+
request::validator::common_body(self.path.as_str(), request_fields, open_api)
33170
}
332-
Ok(())
33371
}

0 commit comments

Comments
 (0)