diff --git a/fs/cifs/dns_resolve.c b/fs/cifs/dns_resolve.c index a2e0673e1b08..1e0c1bd8f2e4 100644 --- a/fs/cifs/dns_resolve.c +++ b/fs/cifs/dns_resolve.c @@ -29,45 +29,13 @@ #include "cifsproto.h" #include "cifs_debug.h" -static int dns_resolver_instantiate(struct key *key, const void *data, - size_t datalen) -{ - int rc = 0; - char *ip; - - ip = kmalloc(datalen+1, GFP_KERNEL); - if (!ip) - return -ENOMEM; - - memcpy(ip, data, datalen); - ip[datalen] = '\0'; - - rcu_assign_pointer(key->payload.data, ip); - - return rc; -} - -static void -dns_resolver_destroy(struct key *key) -{ - kfree(key->payload.data); -} - -struct key_type key_type_dns_resolver = { - .name = "dns_resolver", - .def_datalen = sizeof(struct in_addr), - .describe = user_describe, - .instantiate = dns_resolver_instantiate, - .destroy = dns_resolver_destroy, - .match = user_match, -}; - /* Checks if supplied name is IP address * returns: * 1 - name is IP * 0 - name is not IP */ -static int is_ip(const char *name) +static int +is_ip(const char *name) { int rc; struct sockaddr_in sin_server; @@ -89,6 +57,47 @@ static int is_ip(const char *name) return 0; } +static int +dns_resolver_instantiate(struct key *key, const void *data, + size_t datalen) +{ + int rc = 0; + char *ip; + + ip = kmalloc(datalen + 1, GFP_KERNEL); + if (!ip) + return -ENOMEM; + + memcpy(ip, data, datalen); + ip[datalen] = '\0'; + + /* make sure this looks like an address */ + if (!is_ip((const char *) ip)) { + kfree(ip); + return -EINVAL; + } + + key->type_data.x[0] = datalen; + rcu_assign_pointer(key->payload.data, ip); + + return rc; +} + +static void +dns_resolver_destroy(struct key *key) +{ + kfree(key->payload.data); +} + +struct key_type key_type_dns_resolver = { + .name = "dns_resolver", + .def_datalen = sizeof(struct in_addr), + .describe = user_describe, + .instantiate = dns_resolver_instantiate, + .destroy = dns_resolver_destroy, + .match = user_match, +}; + /* Resolves server name to ip address. * input: * unc - server UNC @@ -140,6 +149,7 @@ dns_resolve_server_name_to_ip(const char *unc, char **ip_addr) rkey = request_key(&key_type_dns_resolver, name, ""); if (!IS_ERR(rkey)) { + len = rkey->type_data.x[0]; data = rkey->payload.data; } else { cERROR(1, ("%s: unable to resolve: %s", __func__, name)); @@ -148,11 +158,9 @@ dns_resolve_server_name_to_ip(const char *unc, char **ip_addr) skip_upcall: if (data) { - len = strlen(data); - *ip_addr = kmalloc(len+1, GFP_KERNEL); + *ip_addr = kmalloc(len + 1, GFP_KERNEL); if (*ip_addr) { - memcpy(*ip_addr, data, len); - (*ip_addr)[len] = '\0'; + memcpy(*ip_addr, data, len + 1); if (!IS_ERR(rkey)) cFYI(1, ("%s: resolved: %s to %s", __func__, name,