diff --git a/app.py b/app.py index aa1b978..709717b 100644 --- a/app.py +++ b/app.py @@ -130,6 +130,54 @@ def verify_captcha(user_input): conn.close() return False +def validate_name(name, max_length=64): + """ + 校验名称是否符合规范 + 规则: + 1. 长度1-64个字符 + 2. 只能包含字母、数字、中文、下划线、短横线 + 3. 不能以短横线开头或结尾 + """ + if not name or len(name) > max_length: + return False + + # 允许中文、字母、数字、下划线、短横线 + pattern = r'^[a-zA-Z0-9_\-\u4e00-\u9fa5]+$' + if not re.match(pattern, name): + return False + + # 不能以短横线开头或结尾 + if name.startswith('-') or name.endswith('-'): + return False + + return True + + +def validate_common_name(cn): + """ + 校验通用名(Common Name)是否符合规范 + 规则: + 1. 长度1-64个字符 + 2. 只能包含字母、数字、点号(.)和短横线(-) + 3. 不能以点号或短横线开头或结尾 + 4. 不能连续两个点号或短横线 + """ + if not cn or len(cn) > 64: + return False + + # 只允许字母、数字、点号和短横线 + if not re.match(r'^[a-zA-Z0-9.-]+$', cn): + return False + + # 不能以点号或短横线开头或结尾 + if cn.startswith('.') or cn.endswith('.') or cn.startswith('-') or cn.endswith('-'): + return False + + # 不能连续两个点号或短横线 + if '..' in cn or '--' in cn: + return False + + return True def create_ca(ca_name, common_name, organization, organizational_unit, country, state, locality, key_size, days_valid, created_by): @@ -625,16 +673,43 @@ def ca_list(): @login_required def create_ca_view(): if request.method == 'POST': - ca_name = request.form['ca_name'] - common_name = request.form['common_name'] - organization = request.form['organization'] - organizational_unit = request.form['organizational_unit'] - country = request.form['country'] - state = request.form['state'] - locality = request.form['locality'] + ca_name = request.form['ca_name'].strip() + common_name = request.form['common_name'].strip() + organization = request.form['organization'].strip() + organizational_unit = request.form['organizational_unit'].strip() + country = request.form['country'].strip() + state = request.form['state'].strip() + locality = request.form['locality'].strip() key_size = int(request.form['key_size']) days_valid = int(request.form['days_valid']) + # 名称校验 + if not validate_name(ca_name): + flash('CA名称无效:只能包含中文、字母、数字、下划线和短横线,且不能以短横线开头或结尾', 'danger') + return render_template('create_ca.html') + + if not validate_common_name(common_name): + flash('通用名无效:只能包含字母、数字、点号和短横线,且不能以点号或短横线开头或结尾', 'danger') + return render_template('create_ca.html') + + # 检查CA名称是否已存在 + conn = get_db_connection() + if conn: + try: + cursor = conn.cursor(dictionary=True) + cursor.execute("SELECT id FROM certificate_authorities WHERE name = %s", (ca_name,)) + if cursor.fetchone(): + flash('CA名称已存在,请使用其他名称', 'danger') + return render_template('create_ca.html') + except Error as e: + print(f"Database error: {e}") + flash('检查CA名称失败', 'danger') + return render_template('create_ca.html') + finally: + if conn.is_connected(): + cursor.close() + conn.close() + ca_id = create_ca(ca_name, common_name, organization, organizational_unit, country, state, locality, key_size, days_valid, current_user.id) @@ -646,7 +721,6 @@ def create_ca_view(): return render_template('create_ca.html') - from datetime import timedelta # 确保顶部已导入 @app.route('/cas/') @@ -821,18 +895,58 @@ def certificate_list(): @login_required def create_certificate_view(): if request.method == 'POST': - common_name = request.form['common_name'] - san_dns = request.form.get('san_dns', '') - san_ip = request.form.get('san_ip', '') - organization = request.form['organization'] - organizational_unit = request.form['organizational_unit'] - country = request.form['country'] - state = request.form['state'] - locality = request.form['locality'] + common_name = request.form['common_name'].strip() + san_dns = request.form.get('san_dns', '').strip() + san_ip = request.form.get('san_ip', '').strip() + organization = request.form['organization'].strip() + organizational_unit = request.form['organizational_unit'].strip() + country = request.form['country'].strip() + state = request.form['state'].strip() + locality = request.form['locality'].strip() key_size = int(request.form['key_size']) days_valid = int(request.form['days_valid']) ca_id = int(request.form['ca_id']) + # 通用名校验 + if not validate_common_name(common_name): + flash('通用名无效:只能包含字母、数字、点号和短横线,且不能以点号或短横线开头或结尾', 'danger') + return redirect(url_for('create_certificate_view')) + + # SAN DNS校验 + if san_dns: + for dns in san_dns.split(','): + dns = dns.strip() + if not validate_common_name(dns): + flash(f'DNS SAN条目无效: {dns},只能包含字母、数字、点号和短横线', 'danger') + return redirect(url_for('create_certificate_view')) + + # SAN IP校验 + if san_ip: + for ip in san_ip.split(','): + ip = ip.strip() + if not re.match(r'^\d{1,3}\.\d{1,3}\.\d{1,3}\.\d{1,3}$', ip): + flash(f'IP SAN条目无效: {ip},请输入有效的IPv4地址', 'danger') + return redirect(url_for('create_certificate_view')) + + # 检查证书是否已存在 + conn = get_db_connection() + if conn: + try: + cursor = conn.cursor(dictionary=True) + cursor.execute("SELECT id FROM certificates WHERE common_name = %s AND ca_id = %s", + (common_name, ca_id)) + if cursor.fetchone(): + flash('该CA下已存在相同通用名的证书', 'danger') + return redirect(url_for('create_certificate_view')) + except Error as e: + print(f"Database error: {e}") + flash('检查证书名称失败', 'danger') + return redirect(url_for('create_certificate_view')) + finally: + if conn.is_connected(): + cursor.close() + conn.close() + cert_id = create_certificate(ca_id, common_name, san_dns, san_ip, organization, organizational_unit, country, state, locality, key_size, days_valid, current_user.id) @@ -867,7 +981,6 @@ def create_certificate_view(): conn.close() return redirect(url_for('certificate_list')) - @app.route('/certificates/') @login_required def certificate_detail(cert_id): diff --git a/templates/create_ca.html b/templates/create_ca.html index 7cc35cc..9a56b0c 100644 --- a/templates/create_ca.html +++ b/templates/create_ca.html @@ -12,12 +12,16 @@
- +
CA机构的显示名称
- +
证书的Common Name字段
diff --git a/templates/create_certificate.html b/templates/create_certificate.html index 2aca176..f2c8d9f 100644 --- a/templates/create_certificate.html +++ b/templates/create_certificate.html @@ -12,7 +12,9 @@
- +
证书的Common Name字段,通常是域名
@@ -56,12 +58,16 @@
- +
多个DNS用逗号分隔,如: example.com,www.example.com
- +
多个IP用逗号分隔,如: 192.168.1.1,10.0.0.1