Skip to content

Commit

Permalink
apache#13441 [Clojure] Add Spec Validations for the Random namespace
Browse files Browse the repository at this point in the history
  • Loading branch information
hellonico committed Dec 4, 2018
1 parent f2dcd7c commit 2d96340
Show file tree
Hide file tree
Showing 4 changed files with 45 additions and 7 deletions.
Original file line number Diff line number Diff line change
Expand Up @@ -28,7 +28,7 @@
(s/def ::momentum float?)
(s/def ::wd float?)
(s/def ::clip-gradient float?)
(s/def ::lr-scheduler #(instance? FactorScheduler))
(s/def ::lr-scheduler #(instance? FactorScheduler %))
(s/def ::sgd-opts (s/keys :opt-un [::learning-rate ::momentum ::wd ::clip-gradient ::lr-scheduler]))

(defn sgd
Expand Down
31 changes: 28 additions & 3 deletions contrib/clojure-package/src/org/apache/clojure_mxnet/random.clj
Original file line number Diff line number Diff line change
Expand Up @@ -16,8 +16,19 @@
;;

(ns org.apache.clojure-mxnet.random
(:require [org.apache.clojure-mxnet.shape :as mx-shape])
(:import (org.apache.mxnet Random)))
(:require
[org.apache.clojure-mxnet.shape :as mx-shape]
[org.apache.clojure-mxnet.context :as context]
[clojure.spec.alpha :as s]
[org.apache.clojure-mxnet.util :as util])
(:import (org.apache.mxnet Context Random)))

(s/def ::int-or-float (s/or :f float? :i int?))
(s/def ::low ::int-or-float)
(s/def ::high ::int-or-float)
(s/def ::shape-vec (s/coll-of pos-int? :kind vector?))
(s/def ::ctx #(instance? Context %))
(s/def ::uniform-opts (s/keys :opt-un [::ctx]))

(defn uniform
"Generate uniform distribution in [low, high) with shape.
Expand All @@ -29,10 +40,18 @@
out: Output place holder}
returns: The result ndarray with generated result./"
([low high shape-vec {:keys [ctx out] :as opts}]
(util/validate! ::uniform-opts opts "Incorrect random uniform parameters")
(util/validate! ::low low "Incorrect random uniform parameter")
(util/validate! ::high high "Incorrect random uniform parameters")
(util/validate! ::shape-vec shape-vec "Incorrect random uniform parameters")
(Random/uniform (float low) (float high) (mx-shape/->shape shape-vec) ctx out))
([low high shape-vec]
(uniform low high shape-vec {})))

(s/def ::loc ::int-or-float)
(s/def ::scale ::int-or-float)
(s/def ::normal-opts (s/keys :opt-un [::ctx]))

(defn normal
"Generate normal(Gaussian) distribution N(mean, stdvar^^2) with shape.
loc: The standard deviation of the normal distribution
Expand All @@ -43,10 +62,15 @@
out: Output place holder}
returns: The result ndarray with generated result./"
([loc scale shape-vec {:keys [ctx out] :as opts}]
(util/validate! ::normal-opts opts "Incorrect random normal parameters")
(util/validate! ::loc loc "Incorrect random normal parameters")
(util/validate! ::scale scale "Incorrect random normal parameters")
(util/validate! ::shape-vec shape-vec "Incorrect random uniform parameters")
(Random/normal (float loc) (float scale) (mx-shape/->shape shape-vec) ctx out))
([loc scale shape-vec]
(normal loc scale shape-vec {})))

(s/def ::seed-state ::int-or-float)
(defn seed
" Seed the random number generators in mxnet.
This seed will affect behavior of functions in this module,
Expand All @@ -58,4 +82,5 @@
This means if you set the same seed, the random number sequence
generated from GPU0 can be different from CPU."
[seed-state]
(Random/seed (int seed-state)))
(util/validate! ::seed-state seed-state "Incorrect seed parameters")
(Random/seed (int seed-state)))
Original file line number Diff line number Diff line change
Expand Up @@ -462,7 +462,7 @@
test (sym/transpose data)
shape-vec [3 4]
ctx (context/default-context)
arr-data (random/uniform 0 100 shape-vec ctx)
arr-data (random/uniform 0 100 shape-vec {:ctx ctx})
trans (ndarray/transpose (ndarray/copy arr-data))
exec-test (sym/bind test ctx {"data" arr-data})
out (-> (executor/forward exec-test)
Expand Down
Original file line number Diff line number Diff line change
Expand Up @@ -26,9 +26,9 @@
(let [[a b] [-10 10]
shape [100 100]
_ (random/seed 128)
un1 (random/uniform a b shape {:context ctx})
un1 (random/uniform a b shape {:ctx ctx})
_ (random/seed 128)
un2 (random/uniform a b shape {:context ctx})]
un2 (random/uniform a b shape {:ctx ctx})]
(is (= un1 un2))
(is (< (Math/abs
(/ (/ (apply + (ndarray/->vec un1))
Expand All @@ -52,3 +52,16 @@
(is (< (Math/abs (- mean mu)) 0.1))
(is (< (Math/abs (- stddev sigma)) 0.1)))))

(defn random-or-normal [fn_]
(is (thrown? Exception (fn_ 'a 2 [])))
(is (thrown? Exception (fn_ 1 'b [])))
(is (thrown? Exception (fn_ 1 2 [-1])))
(is (thrown? Exception (fn_ 1 2 [2 3 0])))
(is (thrown? Exception (fn_ 1 2 [10 10] {:ctx "a"})))
(let [ctx (context/default-context)]
(is (not (nil? (fn_ 1 1 [100 100] {:ctx ctx}))))))

(deftest test-random-parameters-specs
(random-or-normal random/normal)
(random-or-normal random/uniform)
(is (thrown? Exception (random/seed "a"))))

0 comments on commit 2d96340

Please sign in to comment.